From a85fe02edcffe33319edfd0435ef8d9876bd44d2 Mon Sep 17 00:00:00 2001 From: Tyler Kennedy Date: Sun, 13 Feb 2022 00:21:19 -0500 Subject: [PATCH] Cleanup tests. Use proper pytest fixtures for shared test data. Remove two unnecessary array comparison functions. --- tests/conftest.py | 244 +++++++++++++++++++++++++++++++++++++++++ tests/test_abstract.py | 99 ++++++++++------- tests/test_data.py | 240 ---------------------------------------- tests/test_func.py | 91 +++++++++------ tests/test_pandas.py | 42 +++---- tests/test_polars.py | 30 ++--- tests/test_stream.py | 5 + 7 files changed, 400 insertions(+), 351 deletions(-) create mode 100644 tests/conftest.py delete mode 100644 tests/test_data.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..9756ae7dd --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,244 @@ +from __future__ import print_function + +import pytest +import numpy as np + + +@pytest.fixture(scope='session') +def ford_2012_dates(): + return np.asarray([ + 20120103, 20120104, 20120105, 20120106, 20120109, + 20120110, 20120111, 20120112, 20120113, 20120117, 20120118, 20120119, + 20120120, 20120123, 20120124, 20120125, 20120126, 20120127, 20120130, + 20120131, 20120201, 20120202, 20120203, 20120206, 20120207, 20120208, + 20120209, 20120210, 20120213, 20120214, 20120215, 20120216, 20120217, + 20120221, 20120222, 20120223, 20120224, 20120227, 20120228, 20120229, + 20120301, 20120302, 20120305, 20120306, 20120307, 20120308, 20120309, + 20120312, 20120313, 20120314, 20120315, 20120316, 20120319, 20120320, + 20120321, 20120322, 20120323, 20120326, 20120327, 20120328, 20120329, + 20120330, 20120402, 20120403, 20120404, 20120405, 20120409, 20120410, + 20120411, 20120412, 20120413, 20120416, 20120417, 20120418, 20120419, + 20120420, 20120423, 20120424, 20120425, 20120426, 20120427, 20120430, + 20120501, 20120502, 20120503, 20120504, 20120507, 20120508, 20120509, + 20120510, 20120511, 20120514, 20120515, 20120516, 20120517, 20120518, + 20120521, 20120522, 20120523, 20120524, 20120525, 20120529, 20120530, + 20120531, 20120601, 20120604, 20120605, 20120606, 20120607, 20120608, + 20120611, 20120612, 20120613, 20120614, 20120615, 20120618, 20120619, + 20120620, 20120621, 20120622, 20120625, 20120626, 20120627, 20120628, + 20120629, 20120702, 20120703, 20120705, 20120706, 20120709, 20120710, + 20120711, 20120712, 20120713, 20120716, 20120717, 20120718, 20120719, + 20120720, 20120723, 20120724, 20120725, 20120726, 20120727, 20120730, + 20120731, 20120801, 20120802, 20120803, 20120806, 20120807, 20120808, + 20120809, 20120810, 20120813, 20120814, 20120815, 20120816, 20120817, + 20120820, 20120821, 20120822, 20120823, 20120824, 20120827, 20120828, + 20120829, 20120830, 20120831, 20120904, 20120905, 20120906, 20120907, + 20120910, 20120911, 20120912, 20120913, 20120914, 20120917, 20120918, + 20120919, 20120920, 20120921, 20120924, 20120925, 20120926, 20120927, + 20120928, 20121001, 20121002, 20121003, 20121004, 20121005, 20121008, + 20121009, 20121010, 20121011, 20121012, 20121015, 20121016, 20121017, + 20121018, 20121019, 20121022, 20121023, 20121024, 20121025, 20121026, + 20121031, 20121101, 20121102, 20121105, 20121106, 20121107, 20121108, + 20121109, 20121112, 20121113, 20121114, 20121115, 20121116, 20121119, + 20121120, 20121121, 20121123, 20121126, 20121127, 20121128, 20121129, + 20121130, 20121203, 20121204, 20121205, 20121206, 20121207, 20121210, + 20121211, 20121212, 20121213, 20121214, 20121217, 20121218, 20121219, + 20121220, 20121221, 20121224, 20121226, 20121227, 20121228, 20121231 + ]) + + +@pytest.fixture(scope='session') +def ford_2012(): + return { + 'open': np.asarray([ + 11.00, 11.15, 11.33, 11.74, 11.83, 12.00, 11.74, 12.16, 12.01, + 12.20, 12.03, 12.48, 12.55, 12.69, 12.56, 12.80, 13.03, 11.96, + 12.06, 12.47, 12.73, 12.40, 12.47, 12.85, 12.93, 12.91, 12.89, + 12.52, 12.74, 12.46, 12.47, 12.38, 12.84, 12.74, 12.49, 12.27, + 12.43, 12.11, 12.34, 12.28, 12.48, 12.74, 12.67, 12.23, 12.21, + 12.41, 12.53, 12.57, 12.48, 12.64, 12.90, 12.86, 12.52, 12.48, + 12.59, 12.48, 12.31, 12.45, 12.51, 12.35, 12.33, 12.55, 12.50, + 12.71, 12.46, 12.38, 12.26, 12.19, 11.99, 11.94, 11.98, 12.01, + 11.98, 11.81, 11.81, 11.71, 11.15, 11.61, 11.51, 11.71, 12.03, + 11.42, 11.25, 11.16, 11.13, 10.84, 10.53, 10.60, 10.48, 10.83, + 10.61, 10.41, 10.34, 10.23, 10.16, 10.08, 10.02, 10.25, 10.32, + 10.50, 10.61, 10.69, 10.73, 10.62, 10.33, 10.15, 10.01, 10.29, + 10.73, 10.48, 10.77, 10.47, 10.39, 10.27, 10.40, 10.35, 10.37, + 10.58, 10.65, 10.35, 10.13, 10.06, 10.05, 9.93, 9.95, 9.50, 9.53, + 9.67, 9.47, 9.46, 9.50, 9.33, 9.26, 9.16, 9.22, 9.28, 9.38, 9.45, + 9.28, 9.08, 9.17, 9.17, 9.05, 8.99, 9.04, 9.13, 9.29, 8.99, 9.02, + 9.13, 9.18, 9.25, 9.31, 9.30, 9.35, 9.45, 9.44, 9.50, 9.65, 9.58, + 9.65, 9.50, 9.45, 9.42, 9.51, 9.37, 9.33, 9.30, 9.39, 9.37, 9.45, + 9.66, 9.95, 10.08, 10.18, 10.25, 10.20, 10.41, 10.27, 10.30, 10.49, + 10.48, 10.53, 10.30, 10.35, 9.98, 10.13, 9.99, 9.89, 10.01, 9.82, + 10.06, 10.17, 10.06, 10.21, 10.12, 10.06, 10.14, 10.11, 10.26, + 10.31, 10.36, 10.42, 10.14, 10.02, 10.08, 10.42, 10.35, 10.70, + 11.19, 11.31, 11.15, 11.33, 11.25, 11.07, 10.76, 11.03, 10.89, + 11.02, 10.57, 10.58, 10.65, 10.85, 10.84, 10.98, 11.05, 11.10, + 11.05, 11.32, 11.52, 11.56, 11.40, 11.32, 11.26, 11.27, 11.41, + 11.51, 11.52, 11.46, 11.27, 11.16, 11.48, 11.79, 11.74, 11.55, + 11.67, 12.31, 12.79, 12.55, 12.88 + ]), + 'high': np.asarray([ + 11.25, 11.53, 11.63, 11.80, 11.95, 12.05, 12.18, 12.18, 12.08, + 12.26, 12.37, 12.72, 12.64, 12.84, 12.86, 12.98, 13.05, 12.53, + 12.44, 12.51, 12.75, 12.43, 12.84, 13.00, 12.97, 12.96, 12.90, + 12.66, 12.74, 12.58, 12.57, 12.77, 12.88, 12.76, 12.51, 12.44, + 12.46, 12.36, 12.35, 12.55, 12.77, 12.94, 12.68, 12.25, 12.30, + 12.55, 12.73, 12.59, 12.72, 12.90, 13.04, 12.90, 12.68, 12.61, + 12.67, 12.54, 12.37, 12.50, 12.61, 12.36, 12.52, 12.58, 12.65, + 12.95, 12.52, 12.58, 12.29, 12.28, 12.02, 12.13, 12.03, 12.05, + 12.00, 11.85, 11.88, 11.72, 11.40, 11.61, 11.75, 11.93, 12.04, + 11.47, 11.34, 11.17, 11.15, 10.87, 10.79, 10.64, 10.81, 10.86, + 10.83, 10.53, 10.34, 10.43, 10.25, 10.18, 10.23, 10.40, 10.45, + 10.62, 10.68, 10.88, 10.75, 10.68, 10.37, 10.18, 10.24, 10.58, + 10.78, 10.68, 10.80, 10.55, 10.49, 10.45, 10.42, 10.40, 10.64, + 10.74, 10.68, 10.40, 10.18, 10.08, 10.10, 10.09, 9.98, 9.60, 9.79, + 9.74, 9.52, 9.47, 9.55, 9.38, 9.28, 9.32, 9.32, 9.35, 9.52, 9.50, + 9.35, 9.21, 9.24, 9.20, 9.11, 9.10, 9.18, 9.28, 9.42, 9.03, 9.15, + 9.21, 9.39, 9.38, 9.46, 9.36, 9.42, 9.66, 9.54, 9.67, 9.66, 9.64, + 9.70, 9.56, 9.54, 9.52, 9.52, 9.44, 9.40, 9.34, 9.43, 9.47, 9.62, + 9.96, 10.23, 10.28, 10.25, 10.30, 10.38, 10.57, 10.42, 10.45, 10.66, + 10.52, 10.54, 10.40, 10.37, 10.12, 10.18, 10.00, 10.08, 10.05, + 10.02, 10.15, 10.28, 10.12, 10.25, 10.12, 10.26, 10.25, 10.25, + 10.32, 10.41, 10.57, 10.43, 10.24, 10.11, 10.29, 10.49, 10.42, + 11.17, 11.30, 11.38, 11.35, 11.59, 11.34, 11.23, 11.10, 11.16, + 11.10, 11.05, 10.80, 10.64, 10.90, 11.02, 11.00, 11.10, 11.14, + 11.27, 11.26, 11.53, 11.60, 11.70, 11.44, 11.40, 11.31, 11.50, + 11.53, 11.58, 11.56, 11.50, 11.27, 11.41, 11.68, 11.85, 11.80, + 11.86, 12.40, 12.79, 12.81, 12.88, 13.08 + ]), + 'low': np.asarray([ + 10.99, 11.07, 11.24, 11.52, 11.70, 11.63, 11.65, 11.89, 11.84, + 11.96, 12.00, 12.43, 12.45, 12.55, 12.46, 12.70, 12.66, 11.79, + 12.00, 12.20, 12.29, 12.20, 12.39, 12.71, 12.83, 12.80, 12.67, + 12.37, 12.51, 12.34, 12.33, 12.38, 12.71, 12.46, 12.22, 12.16, + 12.19, 11.99, 12.20, 12.25, 12.45, 12.68, 12.41, 12.00, 12.15, + 12.32, 12.48, 12.37, 12.40, 12.63, 12.83, 12.51, 12.48, 12.39, + 12.55, 12.24, 12.18, 12.39, 12.30, 12.18, 12.24, 12.40, 12.44, + 12.46, 12.32, 12.38, 12.11, 11.65, 11.88, 11.86, 11.84, 11.83, + 11.88, 11.72, 11.58, 11.39, 11.15, 11.36, 11.43, 11.67, 11.52, + 11.15, 11.11, 11.00, 10.85, 10.63, 10.52, 10.40, 10.41, 10.66, + 10.56, 10.30, 10.10, 10.15, 10.01, 9.96, 10.00, 10.15, 10.22, 10.38, + 10.51, 10.68, 10.52, 10.40, 10.06, 9.91, 9.97, 10.27, 10.52, 10.38, + 10.45, 10.31, 10.22, 10.21, 10.26, 10.26, 10.35, 10.52, 10.25, + 10.18, 9.95, 9.96, 9.97, 9.93, 9.46, 9.30, 9.49, 9.53, 9.40, 9.31, + 9.28, 9.26, 9.12, 9.14, 9.15, 9.12, 9.34, 9.33, 9.18, 9.05, 8.95, + 8.91, 8.83, 8.88, 9.01, 9.12, 8.99, 8.82, 8.96, 9.09, 9.18, 9.24, + 9.30, 9.23, 9.25, 9.42, 9.41, 9.49, 9.60, 9.51, 9.52, 9.40, 9.42, + 9.41, 9.38, 9.31, 9.29, 9.25, 9.31, 9.35, 9.39, 9.66, 9.93, 10.06, + 10.13, 10.17, 10.12, 10.39, 10.26, 10.28, 10.45, 10.35, 10.36, + 10.26, 10.06, 9.86, 10.02, 9.81, 9.88, 9.71, 9.76, 9.96, 10.13, + 9.99, 10.02, 9.95, 10.05, 10.09, 10.09, 10.22, 10.26, 10.33, 10.13, + 10.03, 9.97, 10.01, 10.28, 10.22, 10.60, 10.88, 11.15, 11.13, 11.26, + 11.04, 10.89, 10.71, 10.96, 10.86, 10.62, 10.46, 10.38, 10.65, + 10.76, 10.80, 10.96, 10.97, 11.10, 10.98, 11.32, 11.33, 11.40, + 11.23, 11.18, 11.19, 11.26, 11.41, 11.40, 11.43, 11.21, 11.03, + 11.14, 11.40, 11.62, 11.58, 11.47, 11.67, 12.31, 12.36, 12.52, 12.76 + ]), + 'close': np.asarray([ + 11.13, 11.30, 11.59, 11.71, 11.80, 11.80, 12.07, 12.14, 12.04, + 12.02, 12.34, 12.61, 12.59, 12.66, 12.82, 12.93, 12.79, 12.21, + 12.29, 12.42, 12.33, 12.26, 12.79, 12.96, 12.88, 12.84, 12.69, + 12.44, 12.54, 12.48, 12.38, 12.74, 12.75, 12.53, 12.28, 12.40, + 12.23, 12.30, 12.25, 12.38, 12.66, 12.72, 12.46, 12.09, 12.24, + 12.46, 12.58, 12.43, 12.70, 12.88, 12.90, 12.51, 12.63, 12.54, + 12.57, 12.32, 12.32, 12.48, 12.32, 12.32, 12.50, 12.48, 12.62, + 12.64, 12.51, 12.47, 12.22, 11.79, 11.91, 12.07, 11.92, 11.88, + 11.91, 11.79, 11.66, 11.41, 11.35, 11.39, 11.73, 11.87, 11.60, + 11.28, 11.23, 11.10, 10.92, 10.67, 10.66, 10.61, 10.69, 10.71, + 10.58, 10.32, 10.15, 10.16, 10.01, 10.01, 10.20, 10.19, 10.41, + 10.59, 10.60, 10.84, 10.66, 10.56, 10.12, 10.04, 10.19, 10.57, + 10.55, 10.66, 10.45, 10.50, 10.30, 10.41, 10.35, 10.34, 10.56, + 10.65, 10.27, 10.19, 10.01, 10.01, 10.02, 10.09, 9.59, 9.39, 9.60, + 9.57, 9.50, 9.45, 9.35, 9.33, 9.13, 9.27, 9.26, 9.34, 9.38, 9.35, + 9.21, 9.17, 9.06, 8.97, 8.96, 9.00, 9.10, 9.24, 9.04, 8.92, 9.09, + 9.15, 9.31, 9.35, 9.34, 9.35, 9.40, 9.44, 9.49, 9.59, 9.63, 9.63, + 9.53, 9.49, 9.45, 9.49, 9.39, 9.34, 9.32, 9.31, 9.34, 9.41, 9.57, + 9.92, 10.14, 10.11, 10.15, 10.21, 10.34, 10.53, 10.39, 10.42, 10.59, + 10.44, 10.40, 10.32, 10.09, 10.01, 10.02, 9.86, 9.93, 9.79, 9.94, + 10.11, 10.16, 10.05, 10.10, 9.98, 10.14, 10.12, 10.22, 10.30, 10.41, + 10.43, 10.18, 10.17, 10.00, 10.17, 10.39, 10.36, 11.16, 11.25, + 11.17, 11.25, 11.42, 11.06, 10.90, 10.93, 10.97, 11.00, 10.67, + 10.57, 10.50, 10.83, 10.85, 10.92, 11.10, 11.11, 11.10, 11.25, + 11.53, 11.45, 11.41, 11.31, 11.31, 11.24, 11.48, 11.47, 11.49, + 11.47, 11.27, 11.10, 11.39, 11.67, 11.73, 11.77, 11.86, 12.40, + 12.79, 12.76, 12.87, 12.95 + ]), + 'volume': np.asarray([ + 45709900, 79725200, 67877500, 59840700, 53981500, 121750600, + 63806000, 48687700, 46366700, 44398400, 47102700, 70894200, + 43705700, 49379700, 45768400, 54021600, 75470700, 142155300, + 57752600, 46412100, 71669000, 48347600, 78851200, 46363300, + 39413500, 35352500, 52290500, 52505500, 34474400, 39627900, + 38174800, 49164400, 30778000, 38409800, 43326000, 36747600, + 31399300, 38703400, 30789000, 62093700, 68262000, 49063500, + 28433700, 57374500, 28440900, 37099100, 36159300, 30275700, + 42783600, 47578500, 55286600, 77119600, 52445700, 40214400, + 27521400, 50117100, 44755000, 26692200, 35070700, 41051700, + 51039700, 36381000, 43966900, 97034200, 51505000, 37939500, + 42515300, 77370300, 34724400, 26988800, 39675000, 31903500, + 35981200, 32314000, 48169200, 52631000, 31269200, 38615200, + 45185400, 40889300, 83070300, 46156300, 43959200, 48572900, + 40238400, 53268400, 33235200, 46174500, 54501200, 42526100, + 36561300, 50225200, 41886500, 44321300, 49648900, 50572000, + 38134900, 44295700, 75647800, 45334100, 30430800, 43760600, + 44592100, 54297000, 68237000, 57305600, 38326200, 50458000, + 33846100, 30811600, 35811400, 35130800, 53471900, 37531800, + 39442000, 27361000, 37155900, 40810100, 40062800, 56427300, + 44297600, 31871900, 33278900, 38648400, 138138600, 63388600, + 49629300, 31783900, 30355400, 37441600, 33516600, 32028700, + 55111000, 30248300, 28838200, 29510000, 31010000, 33615000, + 27968300, 33773800, 53519200, 44338200, 51798900, 67986800, + 40958300, 41360900, 65973000, 45326500, 38631400, 23819100, + 43574500, 22630300, 30909800, 19618800, 21122000, 21129500, + 21308300, 34323700, 34533900, 38923800, 26281100, 26965500, + 23537700, 19574600, 22754200, 23084400, 26115700, 16459400, + 28029200, 37965000, 40608800, 67996400, 60617000, 43381300, + 28165300, 28046500, 50920200, 55934300, 31922200, 34937000, + 42403000, 28755100, 35459800, 28557900, 36866300, 44362600, + 25740900, 44586300, 33445600, 63630000, 51023800, 46855500, + 40693900, 25473900, 38235700, 33951600, 39328700, 24108500, + 26466500, 32788400, 29346300, 44041700, 40493000, 39149700, + 32476500, 49339800, 59290900, 43485500, 137960900, 88770100, + 53399000, 37995000, 51232200, 56674900, 45948800, 40703600, + 25723100, 33342900, 45664700, 48879800, 45346200, 39359100, + 34739800, 21181700, 16032200, 26831700, 37610000, 38496900, + 57289300, 41329600, 47746300, 37760200, 33152400, 31065800, + 38404500, 26025200, 36326900, 31099900, 35443200, 36933500, + 46983300, 61810400, 54884700, 47750100, 94489300, 91734900, + 140331900, 108315100, 95668600, 106908900 + ]) + } + + +@pytest.fixture(scope='session') +def series(): + return np.array([ + 91.50, 94.81, 94.38, 95.09, 93.78, 94.62, 92.53, 92.75, 90.31, 92.47, + 96.12, 97.25, 98.50, 89.88, 91.00, 92.81, 89.16, 89.34, 91.62, 89.88, + 88.38, 87.62, 84.78, 83.00, 83.50, 81.38, 84.44, 89.25, 86.38, 86.25, + 85.25, 87.12, 85.81, 88.97, 88.47, 86.88, 86.81, 84.88, 84.19, 83.88, + 83.38, 85.50, 89.19, 89.44, 91.09, 90.75, 91.44, 89.00, 91.00, 90.50, + 89.03, 88.81, 84.28, 83.50, 82.69, 84.75, 85.66, 86.19, 88.94, 89.28, + 88.62, 88.50, 91.97, 91.50, 93.25, 93.50, 93.16, 91.72, 90.00, 89.69, + 88.88, 85.19, 83.38, 84.88, 85.94, 97.25, 99.88, 104.94, 106.00, 102.50, + 102.41, 104.59, 106.12, 106.00, 106.06, 104.62, 108.62, 109.31, 110.50, + 112.75, 123.00, 119.62, 118.75, 119.25, 117.94, 116.44, 115.19, 111.88, + 110.59, 118.12, 116.00, 116.00, 112.00, 113.75, 112.94, 116.00, 120.50, + 116.62, 117.00, 115.25, 114.31, 115.50, 115.87, 120.69, 120.19, 120.75, + 124.75, 123.37, 122.94, 122.56, 123.12, 122.56, 124.62, 129.25, 131.00, + 132.25, 131.00, 132.81, 134.00, 137.38, 137.81, 137.88, 137.25, 136.31, + 136.25, 134.63, 128.25, 129.00, 123.87, 124.81, 123.00, 126.25, 128.38, + 125.37, 125.69, 122.25, 119.37, 118.50, 123.19, 123.50, 122.19, 119.31, + 123.31, 121.12, 123.37, 127.37, 128.50, 123.87, 122.94, 121.75, 124.44, + 122.00, 122.37, 122.94, 124.00, 123.19, 124.56, 127.25, 125.87, 128.86, + 132.00, 130.75, 134.75, 135.00, 132.38, 133.31, 131.94, 130.00, 125.37, + 130.13, 127.12, 125.19, 122.00, 125.00, 123.00, 123.50, 120.06, 121.00, + 117.75, 119.87, 122.00, 119.19, 116.37, 113.50, 114.25, 110.00, 105.06, + 107.00, 107.87, 107.00, 107.12, 107.00, 91.00, 93.94, 93.87, 95.50, + 93.00, 94.94, 98.25, 96.75, 94.81, 94.37, 91.56, 90.25, 93.94, 93.62, + 97.00, 95.00, 95.87, 94.06, 94.62, 93.75, 98.00, 103.94, 107.87, 106.06, + 104.50, 105.00, 104.19, 103.06, 103.42, 105.27, 111.87, 116.00, 116.62, + 118.28, 113.37, 109.00, 109.70, 109.25, 107.00, 109.19, 110.00, 109.20, + 110.12, 108.00, 108.62, 109.75, 109.81, 109.00, 108.75, 107.87 + ]) \ No newline at end of file diff --git a/tests/test_abstract.py b/tests/test_abstract.py index e98948f70..2018136ef 100644 --- a/tests/test_abstract.py +++ b/tests/test_abstract.py @@ -1,4 +1,5 @@ import numpy as np +from numpy.testing import assert_array_equal, assert_raises import pytest import re @@ -10,10 +11,13 @@ import talib from talib import func from talib import abstract -from talib.test_data import ford_2012, assert_np_arrays_equal, assert_np_arrays_not_equal -def test_pandas(): +def assert_array_not_equal(x, y): + assert_raises(AssertionError, assert_array_equal, x, y) + + +def test_pandas(ford_2012): import pandas input_df = pandas.DataFrame(ford_2012) input_dict = dict((k, pandas.Series(v)) for k, v in ford_2012.items()) @@ -21,23 +25,23 @@ def test_pandas(): expected_k, expected_d = func.STOCH(ford_2012['high'], ford_2012['low'], ford_2012['close']) # 5, 3, 0, 3, 0 output = abstract.Function('stoch', input_df).outputs assert isinstance(output, pandas.DataFrame) - assert_np_arrays_equal(expected_k, output['slowk']) - assert_np_arrays_equal(expected_d, output['slowd']) + assert_array_equal(expected_k, output['slowk']) + assert_array_equal(expected_d, output['slowd']) output = abstract.Function('stoch', input_dict).outputs assert isinstance(output, list) - assert_np_arrays_equal(expected_k, output[0]) - assert_np_arrays_equal(expected_d, output[1]) + assert_array_equal(expected_k, output[0]) + assert_array_equal(expected_d, output[1]) expected = func.SMA(ford_2012['close'], 10) output = abstract.Function('sma', input_df, 10).outputs assert isinstance(output, pandas.Series) - assert_np_arrays_equal(expected, output) + assert_array_equal(expected, output) output = abstract.Function('sma', input_dict, 10).outputs assert isinstance(output, np.ndarray) - assert_np_arrays_equal(expected, output) + assert_array_equal(expected, output) -def test_pandas_series(): +def test_pandas_series(ford_2012): import pandas input_df = pandas.DataFrame(ford_2012) output = talib.SMA(input_df['close'], 10) @@ -60,50 +64,52 @@ def test_pandas_series(): pandas.testing.assert_series_equal(output, expected) -def test_SMA(): +def test_SMA(ford_2012): expected = func.SMA(ford_2012['close'], 10) - assert_np_arrays_equal(expected, abstract.Function('sma', ford_2012, 10).outputs) - assert_np_arrays_equal(expected, abstract.Function('sma')(ford_2012, 10, price='close')) - assert_np_arrays_equal(expected, abstract.Function('sma')(ford_2012, timeperiod=10)) + assert_array_equal(expected, abstract.Function('sma', ford_2012, 10).outputs) + assert_array_equal(expected, abstract.Function('sma')(ford_2012, 10, price='close')) + assert_array_equal(expected, abstract.Function('sma')(ford_2012, timeperiod=10)) expected = func.SMA(ford_2012['open'], 10) - assert_np_arrays_equal(expected, abstract.Function('sma', ford_2012, 10, price='open').outputs) - assert_np_arrays_equal(expected, abstract.Function('sma', price='low')(ford_2012, 10, price='open')) - assert_np_arrays_not_equal(expected, abstract.Function('sma', ford_2012, 10, price='open')(timeperiod=20)) - assert_np_arrays_not_equal(expected, abstract.Function('sma', ford_2012)(10, price='close')) - assert_np_arrays_not_equal(expected, abstract.Function('sma', 10)(ford_2012, price='high')) - assert_np_arrays_not_equal(expected, abstract.Function('sma', price='low')(ford_2012, 10)) + assert_array_equal(expected, abstract.Function('sma', ford_2012, 10, price='open').outputs) + assert_array_equal(expected, abstract.Function('sma', price='low')(ford_2012, 10, price='open')) + assert_array_not_equal(expected, abstract.Function('sma', ford_2012, 10, price='open')(timeperiod=20)) + assert_array_not_equal(expected, abstract.Function('sma', ford_2012)(10, price='close')) + assert_array_not_equal(expected, abstract.Function('sma', 10)(ford_2012, price='high')) + assert_array_not_equal(expected, abstract.Function('sma', price='low')(ford_2012, 10)) input_arrays = {'foobarbaz': ford_2012['open']} - assert_np_arrays_equal(expected, abstract.SMA(input_arrays, 10, price='foobarbaz')) + assert_array_equal(expected, abstract.SMA(input_arrays, 10, price='foobarbaz')) -def test_STOCH(): +def test_STOCH(ford_2012): # check defaults match expected_k, expected_d = func.STOCH(ford_2012['high'], ford_2012['low'], ford_2012['close']) # 5, 3, 0, 3, 0 got_k, got_d = abstract.Function('stoch', ford_2012).outputs - assert_np_arrays_equal(expected_k, got_k) - assert_np_arrays_equal(expected_d, got_d) + assert_array_equal(expected_k, got_k) + assert_array_equal(expected_d, got_d) expected_k, expected_d = func.STOCH(ford_2012['high'], ford_2012['low'], ford_2012['close']) got_k, got_d = abstract.Function('stoch', ford_2012)(5, 3, 0, 3, 0) - assert_np_arrays_equal(expected_k, got_k) - assert_np_arrays_equal(expected_d, got_d) + assert_array_equal(expected_k, got_k) + assert_array_equal(expected_d, got_d) expected_k, expected_d = func.STOCH(ford_2012['high'], ford_2012['low'], ford_2012['close'], 15) got_k, got_d = abstract.Function('stoch', ford_2012)(15, 5, 0, 5, 0) - assert_np_arrays_not_equal(expected_k, got_k) - assert_np_arrays_not_equal(expected_d, got_d) + assert_array_not_equal(expected_k, got_k) + assert_array_not_equal(expected_d, got_d) expected_k, expected_d = func.STOCH(ford_2012['high'], ford_2012['low'], ford_2012['close'], 15, 5, 1, 5, 1) got_k, got_d = abstract.Function('stoch', ford_2012)(15, 5, 1, 5, 1) - assert_np_arrays_equal(expected_k, got_k) - assert_np_arrays_equal(expected_d, got_d) + assert_array_equal(expected_k, got_k) + assert_array_equal(expected_d, got_d) -def test_doji_candle(): + +def test_doji_candle(ford_2012): expected = func.CDLDOJI(ford_2012['open'], ford_2012['high'], ford_2012['low'], ford_2012['close']) got = abstract.Function('CDLDOJI').run(ford_2012) - assert_np_arrays_equal(got, expected) + assert_array_equal(got, expected) + -def test_MAVP(): +def test_MAVP(ford_2012): mavp = abstract.MAVP with pytest.raises(Exception): mavp.set_input_arrays(ford_2012) @@ -113,6 +119,7 @@ def test_MAVP(): assert mavp.set_input_arrays(input_d) assert mavp.input_arrays == input_d + def test_info(): stochrsi = abstract.Function('STOCHRSI') stochrsi.input_names = {'price': 'open'} @@ -158,6 +165,7 @@ def test_info(): } assert expected == abstract.Function('BBANDS').info + def test_input_names(): expected = OrderedDict([('price', 'close')]) assert expected == abstract.Function('MAMA').input_names @@ -177,7 +185,8 @@ def test_input_names(): } assert obv.input_names == expected -def test_input_arrays(): + +def test_input_arrays(ford_2012): mama = abstract.Function('MAMA') # test default setting @@ -211,6 +220,7 @@ def test_input_arrays(): willr.input_names = {'prices': ['high', 'low', 'open']} assert willr.set_input_arrays(input_d) + def test_parameters(): stoch = abstract.Function('STOCH') expected = OrderedDict([ @@ -243,18 +253,20 @@ def test_parameters(): expected['slowd_matype'] = 1 assert expected == stoch.parameters + def test_lookback(): assert abstract.Function('SMA', 10).lookback == 9 stochrsi = abstract.Function('stochrsi', 20, 5, 3) assert stochrsi.lookback == 26 -def test_call_supports_same_signature_as_func_module(): + +def test_call_supports_same_signature_as_func_module(ford_2012): adx = abstract.Function('ADX') expected = func.ADX(ford_2012['open'], ford_2012['high'], ford_2012['low']) output = adx(ford_2012['open'], ford_2012['high'], ford_2012['low']) - assert_np_arrays_equal(expected, output) + assert_array_equal(expected, output) expected_error = re.escape('Too many price arguments: expected 3 (high, low, close)') @@ -266,7 +278,8 @@ def test_call_supports_same_signature_as_func_module(): with pytest.raises(TypeError, match=expected_error): adx(ford_2012['open'], ford_2012['high']) -def test_parameter_type_checking(): + +def test_parameter_type_checking(ford_2012): sma = abstract.Function('SMA', timeperiod=10) expected_error = re.escape('Invalid parameter value for timeperiod (expected int, got float)') @@ -283,29 +296,31 @@ def test_parameter_type_checking(): with pytest.raises(TypeError, match=expected_error): sma.set_parameters(timeperiod=35.5) -def test_call_doesnt_cache_parameters(): + +def test_call_doesnt_cache_parameters(ford_2012): sma = abstract.Function('SMA', timeperiod=10) expected = func.SMA(ford_2012['open'], 20) output = sma(ford_2012, timeperiod=20, price='open') - assert_np_arrays_equal(expected, output) + assert_array_equal(expected, output) expected = func.SMA(ford_2012['close'], 20) output = sma(ford_2012, timeperiod=20) - assert_np_arrays_equal(expected, output) + assert_array_equal(expected, output) expected = func.SMA(ford_2012['close'], 10) output = sma(ford_2012) - assert_np_arrays_equal(expected, output) + assert_array_equal(expected, output) -def test_call_without_arguments(): +def test_call_without_arguments(): with pytest.raises(TypeError, match='Not enough price arguments'): abstract.Function('SMA')() with pytest.raises(TypeError, match='Not enough price arguments'): abstract.Function('SMA')(10) + def test_call_first_exception(): inputs = {'close': np.array([np.nan, np.nan, np.nan])} @@ -316,4 +331,4 @@ def test_call_first_exception(): output = abstract.SMA(inputs, timeperiod=2) expected = np.array([np.nan, 1.5, 2.5]) - assert_np_arrays_equal(expected, output) + assert_array_equal(expected, output) diff --git a/tests/test_data.py b/tests/test_data.py deleted file mode 100644 index d2ba123d7..000000000 --- a/tests/test_data.py +++ /dev/null @@ -1,240 +0,0 @@ - -from __future__ import print_function - -import numpy as np - -ford_2012_dates = np.asarray([ 20120103, 20120104, 20120105, 20120106, 20120109, - 20120110, 20120111, 20120112, 20120113, 20120117, 20120118, 20120119, - 20120120, 20120123, 20120124, 20120125, 20120126, 20120127, 20120130, - 20120131, 20120201, 20120202, 20120203, 20120206, 20120207, 20120208, - 20120209, 20120210, 20120213, 20120214, 20120215, 20120216, 20120217, - 20120221, 20120222, 20120223, 20120224, 20120227, 20120228, 20120229, - 20120301, 20120302, 20120305, 20120306, 20120307, 20120308, 20120309, - 20120312, 20120313, 20120314, 20120315, 20120316, 20120319, 20120320, - 20120321, 20120322, 20120323, 20120326, 20120327, 20120328, 20120329, - 20120330, 20120402, 20120403, 20120404, 20120405, 20120409, 20120410, - 20120411, 20120412, 20120413, 20120416, 20120417, 20120418, 20120419, - 20120420, 20120423, 20120424, 20120425, 20120426, 20120427, 20120430, - 20120501, 20120502, 20120503, 20120504, 20120507, 20120508, 20120509, - 20120510, 20120511, 20120514, 20120515, 20120516, 20120517, 20120518, - 20120521, 20120522, 20120523, 20120524, 20120525, 20120529, 20120530, - 20120531, 20120601, 20120604, 20120605, 20120606, 20120607, 20120608, - 20120611, 20120612, 20120613, 20120614, 20120615, 20120618, 20120619, - 20120620, 20120621, 20120622, 20120625, 20120626, 20120627, 20120628, - 20120629, 20120702, 20120703, 20120705, 20120706, 20120709, 20120710, - 20120711, 20120712, 20120713, 20120716, 20120717, 20120718, 20120719, - 20120720, 20120723, 20120724, 20120725, 20120726, 20120727, 20120730, - 20120731, 20120801, 20120802, 20120803, 20120806, 20120807, 20120808, - 20120809, 20120810, 20120813, 20120814, 20120815, 20120816, 20120817, - 20120820, 20120821, 20120822, 20120823, 20120824, 20120827, 20120828, - 20120829, 20120830, 20120831, 20120904, 20120905, 20120906, 20120907, - 20120910, 20120911, 20120912, 20120913, 20120914, 20120917, 20120918, - 20120919, 20120920, 20120921, 20120924, 20120925, 20120926, 20120927, - 20120928, 20121001, 20121002, 20121003, 20121004, 20121005, 20121008, - 20121009, 20121010, 20121011, 20121012, 20121015, 20121016, 20121017, - 20121018, 20121019, 20121022, 20121023, 20121024, 20121025, 20121026, - 20121031, 20121101, 20121102, 20121105, 20121106, 20121107, 20121108, - 20121109, 20121112, 20121113, 20121114, 20121115, 20121116, 20121119, - 20121120, 20121121, 20121123, 20121126, 20121127, 20121128, 20121129, - 20121130, 20121203, 20121204, 20121205, 20121206, 20121207, 20121210, - 20121211, 20121212, 20121213, 20121214, 20121217, 20121218, 20121219, - 20121220, 20121221, 20121224, 20121226, 20121227, 20121228, 20121231 ]) - -ford_2012 = { - 'open': np.asarray([ 11.00, 11.15, 11.33, 11.74, 11.83, 12.00, 11.74, 12.16, - 12.01, 12.20, 12.03, 12.48, 12.55, 12.69, 12.56, 12.80, 13.03, 11.96, - 12.06, 12.47, 12.73, 12.40, 12.47, 12.85, 12.93, 12.91, 12.89, 12.52, - 12.74, 12.46, 12.47, 12.38, 12.84, 12.74, 12.49, 12.27, 12.43, 12.11, - 12.34, 12.28, 12.48, 12.74, 12.67, 12.23, 12.21, 12.41, 12.53, 12.57, - 12.48, 12.64, 12.90, 12.86, 12.52, 12.48, 12.59, 12.48, 12.31, 12.45, - 12.51, 12.35, 12.33, 12.55, 12.50, 12.71, 12.46, 12.38, 12.26, 12.19, - 11.99, 11.94, 11.98, 12.01, 11.98, 11.81, 11.81, 11.71, 11.15, 11.61, - 11.51, 11.71, 12.03, 11.42, 11.25, 11.16, 11.13, 10.84, 10.53, 10.60, - 10.48, 10.83, 10.61, 10.41, 10.34, 10.23, 10.16, 10.08, 10.02, 10.25, - 10.32, 10.50, 10.61, 10.69, 10.73, 10.62, 10.33, 10.15, 10.01, 10.29, - 10.73, 10.48, 10.77, 10.47, 10.39, 10.27, 10.40, 10.35, 10.37, 10.58, - 10.65, 10.35, 10.13, 10.06, 10.05, 9.93, 9.95, 9.50, 9.53, 9.67, 9.47, - 9.46, 9.50, 9.33, 9.26, 9.16, 9.22, 9.28, 9.38, 9.45, 9.28, 9.08, 9.17, - 9.17, 9.05, 8.99, 9.04, 9.13, 9.29, 8.99, 9.02, 9.13, 9.18, 9.25, 9.31, - 9.30, 9.35, 9.45, 9.44, 9.50, 9.65, 9.58, 9.65, 9.50, 9.45, 9.42, 9.51, - 9.37, 9.33, 9.30, 9.39, 9.37, 9.45, 9.66, 9.95, 10.08, 10.18, 10.25, - 10.20, 10.41, 10.27, 10.30, 10.49, 10.48, 10.53, 10.30, 10.35, 9.98, - 10.13, 9.99, 9.89, 10.01, 9.82, 10.06, 10.17, 10.06, 10.21, 10.12, - 10.06, 10.14, 10.11, 10.26, 10.31, 10.36, 10.42, 10.14, 10.02, 10.08, - 10.42, 10.35, 10.70, 11.19, 11.31, 11.15, 11.33, 11.25, 11.07, 10.76, - 11.03, 10.89, 11.02, 10.57, 10.58, 10.65, 10.85, 10.84, 10.98, 11.05, - 11.10, 11.05, 11.32, 11.52, 11.56, 11.40, 11.32, 11.26, 11.27, 11.41, - 11.51, 11.52, 11.46, 11.27, 11.16, 11.48, 11.79, 11.74, 11.55, 11.67, - 12.31, 12.79, 12.55, 12.88, ]), - - 'high': np.asarray([ 11.25, 11.53, 11.63, 11.80, 11.95, 12.05, 12.18, 12.18, - 12.08, 12.26, 12.37, 12.72, 12.64, 12.84, 12.86, 12.98, 13.05, 12.53, - 12.44, 12.51, 12.75, 12.43, 12.84, 13.00, 12.97, 12.96, 12.90, 12.66, - 12.74, 12.58, 12.57, 12.77, 12.88, 12.76, 12.51, 12.44, 12.46, 12.36, - 12.35, 12.55, 12.77, 12.94, 12.68, 12.25, 12.30, 12.55, 12.73, 12.59, - 12.72, 12.90, 13.04, 12.90, 12.68, 12.61, 12.67, 12.54, 12.37, 12.50, - 12.61, 12.36, 12.52, 12.58, 12.65, 12.95, 12.52, 12.58, 12.29, 12.28, - 12.02, 12.13, 12.03, 12.05, 12.00, 11.85, 11.88, 11.72, 11.40, 11.61, - 11.75, 11.93, 12.04, 11.47, 11.34, 11.17, 11.15, 10.87, 10.79, 10.64, - 10.81, 10.86, 10.83, 10.53, 10.34, 10.43, 10.25, 10.18, 10.23, 10.40, - 10.45, 10.62, 10.68, 10.88, 10.75, 10.68, 10.37, 10.18, 10.24, 10.58, - 10.78, 10.68, 10.80, 10.55, 10.49, 10.45, 10.42, 10.40, 10.64, 10.74, - 10.68, 10.40, 10.18, 10.08, 10.10, 10.09, 9.98, 9.60, 9.79, 9.74, 9.52, - 9.47, 9.55, 9.38, 9.28, 9.32, 9.32, 9.35, 9.52, 9.50, 9.35, 9.21, 9.24, - 9.20, 9.11, 9.10, 9.18, 9.28, 9.42, 9.03, 9.15, 9.21, 9.39, 9.38, 9.46, - 9.36, 9.42, 9.66, 9.54, 9.67, 9.66, 9.64, 9.70, 9.56, 9.54, 9.52, 9.52, - 9.44, 9.40, 9.34, 9.43, 9.47, 9.62, 9.96, 10.23, 10.28, 10.25, 10.30, - 10.38, 10.57, 10.42, 10.45, 10.66, 10.52, 10.54, 10.40, 10.37, 10.12, - 10.18, 10.00, 10.08, 10.05, 10.02, 10.15, 10.28, 10.12, 10.25, 10.12, - 10.26, 10.25, 10.25, 10.32, 10.41, 10.57, 10.43, 10.24, 10.11, 10.29, - 10.49, 10.42, 11.17, 11.30, 11.38, 11.35, 11.59, 11.34, 11.23, 11.10, - 11.16, 11.10, 11.05, 10.80, 10.64, 10.90, 11.02, 11.00, 11.10, 11.14, - 11.27, 11.26, 11.53, 11.60, 11.70, 11.44, 11.40, 11.31, 11.50, 11.53, - 11.58, 11.56, 11.50, 11.27, 11.41, 11.68, 11.85, 11.80, 11.86, 12.40, - 12.79, 12.81, 12.88, 13.08, ]), - - 'low': np.asarray([ 10.99, 11.07, 11.24, 11.52, 11.70, 11.63, 11.65, 11.89, - 11.84, 11.96, 12.00, 12.43, 12.45, 12.55, 12.46, 12.70, 12.66, 11.79, - 12.00, 12.20, 12.29, 12.20, 12.39, 12.71, 12.83, 12.80, 12.67, 12.37, - 12.51, 12.34, 12.33, 12.38, 12.71, 12.46, 12.22, 12.16, 12.19, 11.99, - 12.20, 12.25, 12.45, 12.68, 12.41, 12.00, 12.15, 12.32, 12.48, 12.37, - 12.40, 12.63, 12.83, 12.51, 12.48, 12.39, 12.55, 12.24, 12.18, 12.39, - 12.30, 12.18, 12.24, 12.40, 12.44, 12.46, 12.32, 12.38, 12.11, 11.65, - 11.88, 11.86, 11.84, 11.83, 11.88, 11.72, 11.58, 11.39, 11.15, 11.36, - 11.43, 11.67, 11.52, 11.15, 11.11, 11.00, 10.85, 10.63, 10.52, 10.40, - 10.41, 10.66, 10.56, 10.30, 10.10, 10.15, 10.01, 9.96, 10.00, 10.15, - 10.22, 10.38, 10.51, 10.68, 10.52, 10.40, 10.06, 9.91, 9.97, 10.27, - 10.52, 10.38, 10.45, 10.31, 10.22, 10.21, 10.26, 10.26, 10.35, 10.52, - 10.25, 10.18, 9.95, 9.96, 9.97, 9.93, 9.46, 9.30, 9.49, 9.53, 9.40, - 9.31, 9.28, 9.26, 9.12, 9.14, 9.15, 9.12, 9.34, 9.33, 9.18, 9.05, 8.95, - 8.91, 8.83, 8.88, 9.01, 9.12, 8.99, 8.82, 8.96, 9.09, 9.18, 9.24, 9.30, - 9.23, 9.25, 9.42, 9.41, 9.49, 9.60, 9.51, 9.52, 9.40, 9.42, 9.41, 9.38, - 9.31, 9.29, 9.25, 9.31, 9.35, 9.39, 9.66, 9.93, 10.06, 10.13, 10.17, - 10.12, 10.39, 10.26, 10.28, 10.45, 10.35, 10.36, 10.26, 10.06, 9.86, - 10.02, 9.81, 9.88, 9.71, 9.76, 9.96, 10.13, 9.99, 10.02, 9.95, 10.05, - 10.09, 10.09, 10.22, 10.26, 10.33, 10.13, 10.03, 9.97, 10.01, 10.28, - 10.22, 10.60, 10.88, 11.15, 11.13, 11.26, 11.04, 10.89, 10.71, 10.96, - 10.86, 10.62, 10.46, 10.38, 10.65, 10.76, 10.80, 10.96, 10.97, 11.10, - 10.98, 11.32, 11.33, 11.40, 11.23, 11.18, 11.19, 11.26, 11.41, 11.40, - 11.43, 11.21, 11.03, 11.14, 11.40, 11.62, 11.58, 11.47, 11.67, 12.31, - 12.36, 12.52, 12.76, ]), - - 'close': np.asarray([ 11.13, 11.30, 11.59, 11.71, 11.80, 11.80, 12.07, 12.14, - 12.04, 12.02, 12.34, 12.61, 12.59, 12.66, 12.82, 12.93, 12.79, 12.21, - 12.29, 12.42, 12.33, 12.26, 12.79, 12.96, 12.88, 12.84, 12.69, 12.44, - 12.54, 12.48, 12.38, 12.74, 12.75, 12.53, 12.28, 12.40, 12.23, 12.30, - 12.25, 12.38, 12.66, 12.72, 12.46, 12.09, 12.24, 12.46, 12.58, 12.43, - 12.70, 12.88, 12.90, 12.51, 12.63, 12.54, 12.57, 12.32, 12.32, 12.48, - 12.32, 12.32, 12.50, 12.48, 12.62, 12.64, 12.51, 12.47, 12.22, 11.79, - 11.91, 12.07, 11.92, 11.88, 11.91, 11.79, 11.66, 11.41, 11.35, 11.39, - 11.73, 11.87, 11.60, 11.28, 11.23, 11.10, 10.92, 10.67, 10.66, 10.61, - 10.69, 10.71, 10.58, 10.32, 10.15, 10.16, 10.01, 10.01, 10.20, 10.19, - 10.41, 10.59, 10.60, 10.84, 10.66, 10.56, 10.12, 10.04, 10.19, 10.57, - 10.55, 10.66, 10.45, 10.50, 10.30, 10.41, 10.35, 10.34, 10.56, 10.65, - 10.27, 10.19, 10.01, 10.01, 10.02, 10.09, 9.59, 9.39, 9.60, 9.57, 9.50, - 9.45, 9.35, 9.33, 9.13, 9.27, 9.26, 9.34, 9.38, 9.35, 9.21, 9.17, 9.06, - 8.97, 8.96, 9.00, 9.10, 9.24, 9.04, 8.92, 9.09, 9.15, 9.31, 9.35, 9.34, - 9.35, 9.40, 9.44, 9.49, 9.59, 9.63, 9.63, 9.53, 9.49, 9.45, 9.49, 9.39, - 9.34, 9.32, 9.31, 9.34, 9.41, 9.57, 9.92, 10.14, 10.11, 10.15, 10.21, - 10.34, 10.53, 10.39, 10.42, 10.59, 10.44, 10.40, 10.32, 10.09, 10.01, - 10.02, 9.86, 9.93, 9.79, 9.94, 10.11, 10.16, 10.05, 10.10, 9.98, 10.14, - 10.12, 10.22, 10.30, 10.41, 10.43, 10.18, 10.17, 10.00, 10.17, 10.39, - 10.36, 11.16, 11.25, 11.17, 11.25, 11.42, 11.06, 10.90, 10.93, 10.97, - 11.00, 10.67, 10.57, 10.50, 10.83, 10.85, 10.92, 11.10, 11.11, 11.10, - 11.25, 11.53, 11.45, 11.41, 11.31, 11.31, 11.24, 11.48, 11.47, 11.49, - 11.47, 11.27, 11.10, 11.39, 11.67, 11.73, 11.77, 11.86, 12.40, 12.79, - 12.76, 12.87, 12.95, ]), - - 'volume': np.asarray([ 45709900, 79725200, 67877500, 59840700, 53981500, - 121750600, 63806000, 48687700, 46366700, 44398400, 47102700, 70894200, - 43705700, 49379700, 45768400, 54021600, 75470700, 142155300, 57752600, - 46412100, 71669000, 48347600, 78851200, 46363300, 39413500, 35352500, - 52290500, 52505500, 34474400, 39627900, 38174800, 49164400, 30778000, - 38409800, 43326000, 36747600, 31399300, 38703400, 30789000, 62093700, - 68262000, 49063500, 28433700, 57374500, 28440900, 37099100, 36159300, - 30275700, 42783600, 47578500, 55286600, 77119600, 52445700, 40214400, - 27521400, 50117100, 44755000, 26692200, 35070700, 41051700, 51039700, - 36381000, 43966900, 97034200, 51505000, 37939500, 42515300, 77370300, - 34724400, 26988800, 39675000, 31903500, 35981200, 32314000, 48169200, - 52631000, 31269200, 38615200, 45185400, 40889300, 83070300, 46156300, - 43959200, 48572900, 40238400, 53268400, 33235200, 46174500, 54501200, - 42526100, 36561300, 50225200, 41886500, 44321300, 49648900, 50572000, - 38134900, 44295700, 75647800, 45334100, 30430800, 43760600, 44592100, - 54297000, 68237000, 57305600, 38326200, 50458000, 33846100, 30811600, - 35811400, 35130800, 53471900, 37531800, 39442000, 27361000, 37155900, - 40810100, 40062800, 56427300, 44297600, 31871900, 33278900, 38648400, - 138138600, 63388600, 49629300, 31783900, 30355400, 37441600, 33516600, - 32028700, 55111000, 30248300, 28838200, 29510000, 31010000, 33615000, - 27968300, 33773800, 53519200, 44338200, 51798900, 67986800, 40958300, - 41360900, 65973000, 45326500, 38631400, 23819100, 43574500, 22630300, - 30909800, 19618800, 21122000, 21129500, 21308300, 34323700, 34533900, - 38923800, 26281100, 26965500, 23537700, 19574600, 22754200, 23084400, - 26115700, 16459400, 28029200, 37965000, 40608800, 67996400, 60617000, - 43381300, 28165300, 28046500, 50920200, 55934300, 31922200, 34937000, - 42403000, 28755100, 35459800, 28557900, 36866300, 44362600, 25740900, - 44586300, 33445600, 63630000, 51023800, 46855500, 40693900, 25473900, - 38235700, 33951600, 39328700, 24108500, 26466500, 32788400, 29346300, - 44041700, 40493000, 39149700, 32476500, 49339800, 59290900, 43485500, - 137960900, 88770100, 53399000, 37995000, 51232200, 56674900, 45948800, - 40703600, 25723100, 33342900, 45664700, 48879800, 45346200, 39359100, - 34739800, 21181700, 16032200, 26831700, 37610000, 38496900, 57289300, - 41329600, 47746300, 37760200, 33152400, 31065800, 38404500, 26025200, - 36326900, 31099900, 35443200, 36933500, 46983300, 61810400, 54884700, - 47750100, 94489300, 91734900, 140331900, 108315100, 95668600, 106908900 ]), - } - -series = np.array([ 91.50, 94.81, 94.38, 95.09, 93.78, 94.62, 92.53, 92.75, - 90.31, 92.47, 96.12, 97.25, 98.50, 89.88, 91.00, 92.81, 89.16, 89.34, - 91.62, 89.88, 88.38, 87.62, 84.78, 83.00, 83.50, 81.38, 84.44, 89.25, - 86.38, 86.25, 85.25, 87.12, 85.81, 88.97, 88.47, 86.88, 86.81, 84.88, - 84.19, 83.88, 83.38, 85.50, 89.19, 89.44, 91.09, 90.75, 91.44, 89.00, - 91.00, 90.50, 89.03, 88.81, 84.28, 83.50, 82.69, 84.75, 85.66, 86.19, - 88.94, 89.28, 88.62, 88.50, 91.97, 91.50, 93.25, 93.50, 93.16, 91.72, - 90.00, 89.69, 88.88, 85.19, 83.38, 84.88, 85.94, 97.25, 99.88, 104.94, - 106.00, 102.50, 102.41, 104.59, 106.12, 106.00, 106.06, 104.62, 108.62, - 109.31, 110.50, 112.75, 123.00, 119.62, 118.75, 119.25, 117.94, 116.44, - 115.19, 111.88, 110.59, 118.12, 116.00, 116.00, 112.00, 113.75, 112.94, - 116.00, 120.50, 116.62, 117.00, 115.25, 114.31, 115.50, 115.87, 120.69, - 120.19, 120.75, 124.75, 123.37, 122.94, 122.56, 123.12, 122.56, 124.62, - 129.25, 131.00, 132.25, 131.00, 132.81, 134.00, 137.38, 137.81, 137.88, - 137.25, 136.31, 136.25, 134.63, 128.25, 129.00, 123.87, 124.81, 123.00, - 126.25, 128.38, 125.37, 125.69, 122.25, 119.37, 118.50, 123.19, 123.50, - 122.19, 119.31, 123.31, 121.12, 123.37, 127.37, 128.50, 123.87, 122.94, - 121.75, 124.44, 122.00, 122.37, 122.94, 124.00, 123.19, 124.56, 127.25, - 125.87, 128.86, 132.00, 130.75, 134.75, 135.00, 132.38, 133.31, 131.94, - 130.00, 125.37, 130.13, 127.12, 125.19, 122.00, 125.00, 123.00, 123.50, - 120.06, 121.00, 117.75, 119.87, 122.00, 119.19, 116.37, 113.50, 114.25, - 110.00, 105.06, 107.00, 107.87, 107.00, 107.12, 107.00, 91.00, 93.94, - 93.87, 95.50, 93.00, 94.94, 98.25, 96.75, 94.81, 94.37, 91.56, 90.25, - 93.94, 93.62, 97.00, 95.00, 95.87, 94.06, 94.62, 93.75, 98.00, 103.94, - 107.87, 106.06, 104.50, 105.00, 104.19, 103.06, 103.42, 105.27, 111.87, - 116.00, 116.62, 118.28, 113.37, 109.00, 109.70, 109.25, 107.00, 109.19, - 110.00, 109.20, 110.12, 108.00, 108.62, 109.75, 109.81, 109.00, 108.75, - 107.87 ]) - -def assert_np_arrays_equal(expected, got): - for i, value in enumerate(expected): - if np.isnan(value): - assert np.isnan(got[i]) - else: - assert value == got[i] - -def assert_np_arrays_not_equal(expected, got): - ''' Verifies expected and got have the same number of leading nan fields, - followed by different floats. - ''' - nans = [] - equals = [] - for i, value in enumerate(expected): - if np.isnan(value): - assert np.isnan(got[i]) - nans.append(value) - else: - try: - assert value != got[i] - except AssertionError: - equals.append(got[i]) - if len(equals) == len(expected[len(nans):]): - raise AssertionError('Arrays were equal.') - elif equals: - print('Arrays had %i/%i equivalent values.' % (len(equals), len(expected[len(nans):]))) diff --git a/tests/test_func.py b/tests/test_func.py index 8b2a40a4e..643f1b504 100644 --- a/tests/test_func.py +++ b/tests/test_func.py @@ -1,21 +1,25 @@ import numpy as np +from numpy.testing import assert_array_equal import pytest import talib from talib import func -from talib.test_data import series, assert_np_arrays_equal, assert_np_arrays_not_equal + def test_talib_version(): assert talib.__ta_version__[:5] == b'0.4.0' + def test_num_functions(): assert len(talib.get_functions()) == 158 + def test_input_wrong_type(): a1 = np.arange(10, dtype=int) with pytest.raises(Exception): func.MOM(a1) + def test_input_lengths(): a1 = np.arange(10, dtype=float) a2 = np.arange(11, dtype=float) @@ -28,38 +32,42 @@ def test_input_lengths(): with pytest.raises(Exception): func.BOP(a1, a1, a1, a2) + def test_input_nans(): a1 = np.arange(10, dtype=float) a2 = np.arange(10, dtype=float) a2[0] = np.nan a2[1] = np.nan r1, r2 = func.AROON(a1, a2, 2) - assert_np_arrays_equal(r1, [np.nan, np.nan, np.nan, np.nan, 0, 0, 0, 0, 0, 0]) - assert_np_arrays_equal(r2, [np.nan, np.nan, np.nan, np.nan, 100, 100, 100, 100, 100, 100]) + assert_array_equal(r1, [np.nan, np.nan, np.nan, np.nan, 0, 0, 0, 0, 0, 0]) + assert_array_equal(r2, [np.nan, np.nan, np.nan, np.nan, 100, 100, 100, 100, 100, 100]) r1, r2 = func.AROON(a2, a1, 2) - assert_np_arrays_equal(r1, [np.nan, np.nan, np.nan, np.nan, 0, 0, 0, 0, 0, 0]) - assert_np_arrays_equal(r2, [np.nan, np.nan, np.nan, np.nan, 100, 100, 100, 100, 100, 100]) + assert_array_equal(r1, [np.nan, np.nan, np.nan, np.nan, 0, 0, 0, 0, 0, 0]) + assert_array_equal(r2, [np.nan, np.nan, np.nan, np.nan, 100, 100, 100, 100, 100, 100]) + def test_unstable_period(): a = np.arange(10, dtype=float) r = func.EMA(a, 3) - assert_np_arrays_equal(r, [np.nan, np.nan, 1, 2, 3, 4, 5, 6, 7, 8]) + assert_array_equal(r, [np.nan, np.nan, 1, 2, 3, 4, 5, 6, 7, 8]) talib.set_unstable_period('EMA', 5) r = func.EMA(a, 3) - assert_np_arrays_equal(r, [np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, 6, 7, 8]) + assert_array_equal(r, [np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, 6, 7, 8]) talib.set_unstable_period('EMA', 0) + def test_compatibility(): a = np.arange(10, dtype=float) talib.set_compatibility(0) r = func.EMA(a, 3) - assert_np_arrays_equal(r, [np.nan, np.nan, 1, 2, 3, 4, 5, 6, 7, 8]) + assert_array_equal(r, [np.nan, np.nan, 1, 2, 3, 4, 5, 6, 7, 8]) talib.set_compatibility(1) r = func.EMA(a, 3) - assert_np_arrays_equal(r, [np.nan, np.nan,1.25,2.125,3.0625,4.03125,5.015625,6.0078125,7.00390625,8.001953125]) + assert_array_equal(r, [np.nan, np.nan,1.25,2.125,3.0625,4.03125,5.015625,6.0078125,7.00390625,8.001953125]) talib.set_compatibility(0) -def test_MIN(): + +def test_MIN(series): result = func.MIN(series, timeperiod=4) i = np.where(~np.isnan(result))[0][0] assert len(series) == len(result) @@ -69,9 +77,10 @@ def test_MIN(): assert result[i + 4] == 92.530 values = np.array([np.nan, 5., 4., 3., 5., 7.]) result = func.MIN(values, timeperiod=2) - assert_np_arrays_equal(result, [np.nan, np.nan, 4, 3, 3, 5]) + assert_array_equal(result, [np.nan, np.nan, 4, 3, 3, 5]) + -def test_MAX(): +def test_MAX(series): result = func.MAX(series, timeperiod=4) i = np.where(~np.isnan(result))[0][0] assert len(series) == len(result) @@ -80,31 +89,38 @@ def test_MAX(): assert result[i + 4] == 94.620 assert result[i + 5] == 94.620 + def test_MOM(): values = np.array([90.0,88.0,89.0]) result = func.MOM(values, timeperiod=1) - assert_np_arrays_equal(result, [np.nan, -2, 1]) + assert_array_equal(result, [np.nan, -2, 1]) result = func.MOM(values, timeperiod=2) - assert_np_arrays_equal(result, [np.nan, np.nan, -1]) + assert_array_equal(result, [np.nan, np.nan, -1]) result = func.MOM(values, timeperiod=3) - assert_np_arrays_equal(result, [np.nan, np.nan, np.nan]) + assert_array_equal(result, [np.nan, np.nan, np.nan]) result = func.MOM(values, timeperiod=4) - assert_np_arrays_equal(result, [np.nan, np.nan, np.nan]) + assert_array_equal(result, [np.nan, np.nan, np.nan]) + -def test_BBANDS(): - upper, middle, lower = func.BBANDS(series, timeperiod=20, - nbdevup=2.0, nbdevdn=2.0, - matype=talib.MA_Type.EMA) +def test_BBANDS(series): + upper, middle, lower = func.BBANDS( + series, + timeperiod=20, + nbdevup=2.0, + nbdevdn=2.0, + matype=talib.MA_Type.EMA + ) i = np.where(~np.isnan(upper))[0][0] assert len(upper) == len(middle) == len(lower) == len(series) - #assert abs(upper[i + 0] - 98.0734) < 1e-3 + # assert abs(upper[i + 0] - 98.0734) < 1e-3 assert abs(middle[i + 0] - 92.8910) < 1e-3 assert abs(lower[i + 0] - 87.7086) < 1e-3 - #assert abs(upper[i + 13] - 93.674) < 1e-3 + # assert abs(upper[i + 13] - 93.674) < 1e-3 assert abs(middle[i + 13] - 87.679) < 1e-3 assert abs(lower[i + 13] - 81.685) < 1e-3 -def test_DEMA(): + +def test_DEMA(series): result = func.DEMA(series) i = np.where(~np.isnan(result))[0][0] assert len(series) == len(result) @@ -113,13 +129,15 @@ def test_DEMA(): assert abs(result[i + 3] - 87.089) < 1e-3 assert abs(result[i + 4] - 87.656) < 1e-3 -def test_EMAEMA(): + +def test_EMAEMA(series): result = func.EMA(series, timeperiod=2) result = func.EMA(result, timeperiod=2) i = np.where(~np.isnan(result))[0][0] assert len(series) == len(result) assert i == 2 + def test_CDL3BLACKCROWS(): o = np.array([39.00, 39.00, 39.00, 39.00, 39.00, 39.00, 39.00, 39.00, 39.00, 39.00, 39.00, 39.00, 39.00, 39.00, 40.32, 40.51, 38.09, 35.00, 27.66, 30.80]) h = np.array([40.84, 40.84, 40.84, 40.84, 40.84, 40.84, 40.84, 40.84, 40.84, 40.84, 40.84, 40.84, 40.84, 40.84, 41.69, 40.84, 38.12, 35.50, 31.74, 32.51]) @@ -127,7 +145,8 @@ def test_CDL3BLACKCROWS(): c = np.array([40.29, 40.29, 40.29, 40.29, 40.29, 40.29, 40.29, 40.29, 40.29, 40.29, 40.29, 40.29, 40.29, 40.29, 40.46, 37.08, 33.37, 30.03, 31.46, 28.31]) result = func.CDL3BLACKCROWS(o, h, l, c) - assert_np_arrays_equal(result, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -100, 0, 0]) + assert_array_equal(result, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -100, 0, 0]) + def test_RSI(): a = np.array([0.00000024, 0.00000024, 0.00000024, @@ -138,24 +157,26 @@ def test_RSI(): 0.00000024, 0.00000024, 0.00000023, 0.00000023, 0.00000023], dtype='float64') result = func.RSI(a, 10) - assert_np_arrays_equal(result, [np.nan,np.nan,np.nan,np.nan,np.nan,np.nan,np.nan,np.nan,np.nan,np.nan,0,0,0,0,0,0,0,0,0,0]) + assert_array_equal(result, [np.nan,np.nan,np.nan,np.nan,np.nan,np.nan,np.nan,np.nan,np.nan,np.nan,0,0,0,0,0,0,0,0,0,0]) result = func.RSI(a * 100000, 10) - assert_np_arrays_equal(result, [np.nan,np.nan,np.nan,np.nan,np.nan,np.nan,np.nan,np.nan,np.nan,np.nan,33.333333333333329,51.351351351351347,39.491916859122398,51.84807024709005,42.25953803191981,52.101824405061215,52.101824405061215,43.043664867691085,43.043664867691085,43.043664867691085]) + assert_array_equal(result, [np.nan,np.nan,np.nan,np.nan,np.nan,np.nan,np.nan,np.nan,np.nan,np.nan,33.333333333333329,51.351351351351347,39.491916859122398,51.84807024709005,42.25953803191981,52.101824405061215,52.101824405061215,43.043664867691085,43.043664867691085,43.043664867691085]) + def test_MAVP(): a = np.array([1,5,3,4,7,3,8,1,4,6], dtype=float) b = np.array([2,4,2,4,2,4,2,4,2,4], dtype=float) result = func.MAVP(a, b, minperiod=2, maxperiod=4) - assert_np_arrays_equal(result, [np.nan,np.nan,np.nan,3.25,5.5,4.25,5.5,4.75,2.5,4.75]) + assert_array_equal(result, [np.nan,np.nan,np.nan,3.25,5.5,4.25,5.5,4.75,2.5,4.75]) sma2 = func.SMA(a, 2) - assert_np_arrays_equal(result[4::2], sma2[4::2]) + assert_array_equal(result[4::2], sma2[4::2]) sma4 = func.SMA(a, 4) - assert_np_arrays_equal(result[3::2], sma4[3::2]) + assert_array_equal(result[3::2], sma4[3::2]) result = func.MAVP(a, b, minperiod=2, maxperiod=3) - assert_np_arrays_equal(result, [np.nan,np.nan,4,4,5.5,4.666666666666667,5.5,4,2.5,3.6666666666666665]) + assert_array_equal(result, [np.nan,np.nan,4,4,5.5,4.666666666666667,5.5,4,2.5,3.6666666666666665]) sma3 = func.SMA(a, 3) - assert_np_arrays_equal(result[2::2], sma2[2::2]) - assert_np_arrays_equal(result[3::2], sma3[3::2]) + assert_array_equal(result[2::2], sma2[2::2]) + assert_array_equal(result[3::2], sma3[3::2]) + def test_MAXINDEX(): import talib as func @@ -163,7 +184,7 @@ def test_MAXINDEX(): a = np.array([1., 2, 3, 4, 5, 6, 7, 8, 7, 7, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5, 15]) b = func.MA(a, 10) c = func.MAXINDEX(b, 10) - assert_np_arrays_equal(c, [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,16,16,16,21]) + assert_array_equal(c, [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,16,16,16,21]) d = np.array([1., 2, 3]) e = func.MAXINDEX(d, 10) - assert_np_arrays_equal(e, [0,0,0]) + assert_array_equal(e, [0,0,0]) diff --git a/tests/test_pandas.py b/tests/test_pandas.py index d1120ca0d..78f1a1236 100644 --- a/tests/test_pandas.py +++ b/tests/test_pandas.py @@ -1,49 +1,51 @@ import numpy as np +from numpy.testing import assert_array_equal import pandas as pd import talib -from talib.test_data import series, assert_np_arrays_equal + def test_MOM(): values = pd.Series([90.0,88.0,89.0], index=[10, 20, 30]) result = talib.MOM(values, timeperiod=1) assert isinstance(result, pd.Series) - assert_np_arrays_equal(result.values, [np.nan, -2, 1]) - assert_np_arrays_equal(result.index, [10, 20, 30]) + assert_array_equal(result.values, [np.nan, -2, 1]) + assert_array_equal(result.index, [10, 20, 30]) result = talib.MOM(values, timeperiod=2) assert isinstance(result, pd.Series) - assert_np_arrays_equal(result.values, [np.nan, np.nan, -1]) - assert_np_arrays_equal(result.index, [10, 20, 30]) + assert_array_equal(result.values, [np.nan, np.nan, -1]) + assert_array_equal(result.index, [10, 20, 30]) result = talib.MOM(values, timeperiod=3) assert isinstance(result, pd.Series) - assert_np_arrays_equal(result.values, [np.nan, np.nan, np.nan]) - assert_np_arrays_equal(result.index, [10, 20, 30]) + assert_array_equal(result.values, [np.nan, np.nan, np.nan]) + assert_array_equal(result.index, [10, 20, 30]) result = talib.MOM(values, timeperiod=4) assert isinstance(result, pd.Series) - assert_np_arrays_equal(result.values, [np.nan, np.nan, np.nan]) - assert_np_arrays_equal(result.index, [10, 20, 30]) + assert_array_equal(result.values, [np.nan, np.nan, np.nan]) + assert_array_equal(result.index, [10, 20, 30]) + def test_MAVP(): a = pd.Series([1,5,3,4,7,3,8,1,4,6], index=range(10, 20), dtype=float) b = pd.Series([2,4,2,4,2,4,2,4,2,4], index=range(20, 30), dtype=float) result = talib.MAVP(a, b, minperiod=2, maxperiod=4) assert isinstance(result, pd.Series) - assert_np_arrays_equal(result.values, [np.nan,np.nan,np.nan,3.25,5.5,4.25,5.5,4.75,2.5,4.75]) - assert_np_arrays_equal(result.index, range(10, 20)) + assert_array_equal(result.values, [np.nan,np.nan,np.nan,3.25,5.5,4.25,5.5,4.75,2.5,4.75]) + assert_array_equal(result.index, range(10, 20)) sma2 = talib.SMA(a, 2) assert isinstance(sma2, pd.Series) - assert_np_arrays_equal(sma2.index, range(10, 20)) - assert_np_arrays_equal(result.values[4::2], sma2.values[4::2]) + assert_array_equal(sma2.index, range(10, 20)) + assert_array_equal(result.values[4::2], sma2.values[4::2]) sma4 = talib.SMA(a, 4) assert isinstance(sma4, pd.Series) - assert_np_arrays_equal(sma4.index, range(10, 20)) - assert_np_arrays_equal(result.values[3::2], sma4.values[3::2]) + assert_array_equal(sma4.index, range(10, 20)) + assert_array_equal(result.values[3::2], sma4.values[3::2]) result = talib.MAVP(a, b, minperiod=2, maxperiod=3) assert isinstance(result, pd.Series) - assert_np_arrays_equal(result.values, [np.nan,np.nan,4,4,5.5,4.666666666666667,5.5,4,2.5,3.6666666666666665]) - assert_np_arrays_equal(result.index, range(10, 20)) + assert_array_equal(result.values, [np.nan,np.nan,4,4,5.5,4.666666666666667,5.5,4,2.5,3.6666666666666665]) + assert_array_equal(result.index, range(10, 20)) sma3 = talib.SMA(a, 3) assert isinstance(sma3, pd.Series) - assert_np_arrays_equal(sma3.index, range(10, 20)) - assert_np_arrays_equal(result.values[2::2], sma2.values[2::2]) - assert_np_arrays_equal(result.values[3::2], sma3.values[3::2]) + assert_array_equal(sma3.index, range(10, 20)) + assert_array_equal(result.values[2::2], sma2.values[2::2]) + assert_array_equal(result.values[3::2], sma3.values[3::2]) diff --git a/tests/test_polars.py b/tests/test_polars.py index 60f390e80..5bb7f7d24 100644 --- a/tests/test_polars.py +++ b/tests/test_polars.py @@ -1,44 +1,46 @@ import numpy as np +from numpy.testing import assert_array_equal import polars as pl import talib from talib import abstract -from talib.test_data import series, assert_np_arrays_equal def test_MOM(): values = pl.Series([90.0,88.0,89.0]) result = talib.MOM(values, timeperiod=1) assert isinstance(result, pl.Series) - assert_np_arrays_equal(result.to_numpy(), [np.nan, -2, 1]) + assert_array_equal(result.to_numpy(), [np.nan, -2, 1]) result = talib.MOM(values, timeperiod=2) assert isinstance(result, pl.Series) - assert_np_arrays_equal(result.to_numpy(), [np.nan, np.nan, -1]) + assert_array_equal(result.to_numpy(), [np.nan, np.nan, -1]) result = talib.MOM(values, timeperiod=3) assert isinstance(result, pl.Series) - assert_np_arrays_equal(result.to_numpy(), [np.nan, np.nan, np.nan]) + assert_array_equal(result.to_numpy(), [np.nan, np.nan, np.nan]) result = talib.MOM(values, timeperiod=4) assert isinstance(result, pl.Series) - assert_np_arrays_equal(result.to_numpy(), [np.nan, np.nan, np.nan]) + assert_array_equal(result.to_numpy(), [np.nan, np.nan, np.nan]) + def test_MAVP(): a = pl.Series([1,5,3,4,7,3,8,1,4,6], dtype=pl.Float64) b = pl.Series([2,4,2,4,2,4,2,4,2,4], dtype=pl.Float64) result = talib.MAVP(a, b, minperiod=2, maxperiod=4) assert isinstance(result, pl.Series) - assert_np_arrays_equal(result.to_numpy(), [np.nan,np.nan,np.nan,3.25,5.5,4.25,5.5,4.75,2.5,4.75]) + assert_array_equal(result.to_numpy(), [np.nan,np.nan,np.nan,3.25,5.5,4.25,5.5,4.75,2.5,4.75]) sma2 = talib.SMA(a, 2) assert isinstance(sma2, pl.Series) - assert_np_arrays_equal(result.to_numpy()[4::2], sma2.to_numpy()[4::2]) + assert_array_equal(result.to_numpy()[4::2], sma2.to_numpy()[4::2]) sma4 = talib.SMA(a, 4) assert isinstance(sma4, pl.Series) - assert_np_arrays_equal(result.to_numpy()[3::2], sma4.to_numpy()[3::2]) + assert_array_equal(result.to_numpy()[3::2], sma4.to_numpy()[3::2]) result = talib.MAVP(a, b, minperiod=2, maxperiod=3) assert isinstance(result, pl.Series) - assert_np_arrays_equal(result.to_numpy(), [np.nan,np.nan,4,4,5.5,4.666666666666667,5.5,4,2.5,3.6666666666666665]) + assert_array_equal(result.to_numpy(), [np.nan,np.nan,4,4,5.5,4.666666666666667,5.5,4,2.5,3.6666666666666665]) sma3 = talib.SMA(a, 3) assert isinstance(sma3, pl.Series) - assert_np_arrays_equal(result.to_numpy()[2::2], sma2.to_numpy()[2::2]) - assert_np_arrays_equal(result.to_numpy()[3::2], sma3.to_numpy()[3::2]) + assert_array_equal(result.to_numpy()[2::2], sma2.to_numpy()[2::2]) + assert_array_equal(result.to_numpy()[3::2], sma3.to_numpy()[3::2]) + def test_TEVA(): size = 50 @@ -57,7 +59,7 @@ def test_TEVA(): inputs = abstract.TEMA.get_input_arrays() assert inputs.columns == df.columns for column in df.columns: - assert_np_arrays_equal(inputs[column].to_numpy(), df[column].to_numpy()) + assert_array_equal(inputs[column].to_numpy(), df[column].to_numpy()) tema2 = abstract.TEMA(df, timeperiod=9) assert isinstance(tema2, pl.Series) @@ -65,6 +67,6 @@ def test_TEVA(): inputs = abstract.TEMA.get_input_arrays() assert inputs.columns == df.columns for column in df.columns: - assert_np_arrays_equal(inputs[column].to_numpy(), df[column].to_numpy()) + assert_array_equal(inputs[column].to_numpy(), df[column].to_numpy()) - assert_np_arrays_equal(tema1.to_numpy(), tema2.to_numpy()) + assert_array_equal(tema1.to_numpy(), tema2.to_numpy()) diff --git a/tests/test_stream.py b/tests/test_stream.py index 2374640b4..68fb18602 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -4,6 +4,7 @@ import talib from talib import stream + def test_streaming(): a = np.array([1,1,2,3,5,8,13], dtype=float) r = stream.MOM(a, timeperiod=1) @@ -21,6 +22,7 @@ def test_streaming(): r = stream.MOM(a, timeperiod=7) assert np.isnan(r) + def test_streaming_pandas(): a = pd.Series([1,1,2,3,5,8,13]) r = stream.MOM(a, timeperiod=1) @@ -38,6 +40,7 @@ def test_streaming_pandas(): r = stream.MOM(a, timeperiod=7) assert np.isnan(r) + def test_CDL3BLACKCROWS(): o = np.array([39.00, 39.00, 39.00, 39.00, 39.00, 39.00, 39.00, 39.00, 39.00, 39.00, 39.00, 39.00, 39.00, 39.00, 40.32, 40.51, 38.09, 35.00]) h = np.array([40.84, 40.84, 40.84, 40.84, 40.84, 40.84, 40.84, 40.84, 40.84, 40.84, 40.84, 40.84, 40.84, 40.84, 41.69, 40.84, 38.12, 35.50]) @@ -47,6 +50,7 @@ def test_CDL3BLACKCROWS(): r = stream.CDL3BLACKCROWS(o, h, l, c) assert r == -100 + def test_CDL3BLACKCROWS_pandas(): o = pd.Series([39.00, 39.00, 39.00, 39.00, 39.00, 39.00, 39.00, 39.00, 39.00, 39.00, 39.00, 39.00, 39.00, 39.00, 40.32, 40.51, 38.09, 35.00]) h = pd.Series([40.84, 40.84, 40.84, 40.84, 40.84, 40.84, 40.84, 40.84, 40.84, 40.84, 40.84, 40.84, 40.84, 40.84, 41.69, 40.84, 38.12, 35.50]) @@ -56,6 +60,7 @@ def test_CDL3BLACKCROWS_pandas(): r = stream.CDL3BLACKCROWS(o, h, l, c) assert r == -100 + def test_MAXINDEX(): a = np.array([1., 2, 3, 4, 5, 6, 7, 8, 7, 7, 3, 4, 5, 6, 7, 8, 9, 2, 3, 4, 5, 15]) r = stream.MAXINDEX(a, 10)