Skip to content
This repository has been archived by the owner on Jan 7, 2023. It is now read-only.

Commit

Permalink
Merge pull request #197 from ndawe/master
Browse files Browse the repository at this point in the history
[MRG] multiclass TMVA example plot
  • Loading branch information
ndawe committed May 5, 2015
2 parents a24e1c8 + 89a19af commit 4f858e0
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 23 deletions.
8 changes: 0 additions & 8 deletions AUTHORS

This file was deleted.

7 changes: 7 additions & 0 deletions AUTHORS.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Authors ordered by first contribution:

* Piti Ongmongkolkul ([email protected])
* Noel Dawe ([email protected])
* Christoph Deil
* Peter Waller ([email protected])
* Giordon Stark ([email protected], [email protected])
57 changes: 44 additions & 13 deletions examples/tmva/multiclass.py → examples/tmva/plot_multiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,24 @@
Multiclass Classification with NumPy and TMVA
=============================================
"""
from __future__ import print_function
from array import array
import numpy as np
from numpy.random import RandomState
from root_numpy.tmva import add_classification_events, evaluate_reader
import matplotlib.pyplot as plt
from ROOT import TMVA, TFile, TCut

plt.style.use('ggplot')
RNG = RandomState(42)

# Construct an example multiclass dataset
n_vars = 5
n_events = 1000
class_0 = RNG.multivariate_normal(
np.ones(n_vars) * -3, np.diag(np.ones(n_vars)), n_events)
[-2, -2], np.diag([1, 1]), n_events)
class_1 = RNG.multivariate_normal(
np.zeros(n_vars), np.diag(np.ones(n_vars)), n_events)
[0, 2], np.diag([1, 1]), n_events)
class_2 = RNG.multivariate_normal(
np.ones(n_vars) * 3, np.diag(np.ones(n_vars)), n_events)
[2, -2], np.diag([1, 1]), n_events)
X = np.concatenate([class_0, class_1, class_2])
y = np.ones(X.shape[0])
w = RNG.randint(1, 10, n_events * 3)
Expand All @@ -38,7 +38,7 @@
factory = TMVA.Factory('classifier', output,
'AnalysisType=Multiclass:'
'!V:Silent:!DrawProgressBar')
for n in range(n_vars):
for n in range(2):
factory.AddVariable('f{0}'.format(n), 'F')

# Call root_numpy's utility functions to add events from the arrays
Expand All @@ -48,17 +48,48 @@
# Train a BDT
factory.PrepareTrainingAndTestTree(TCut('1'), 'NormMode=EqualNumEvents')
factory.BookMethod('BDT', 'BDTG',
'nCuts=20:NTrees=10:MaxDepth=3:'
'nCuts=20:NTrees=20:MaxDepth=4:'
'BoostType=Grad:Shrinkage=0.10')
factory.TrainAllMethods()

# Classify the test dataset with the BDT
reader = TMVA.Reader()
for n in range(n_vars):
for n in range(2):
reader.AddVariable('f{0}'.format(n), array('f', [0.]))
reader.BookMVA('BDT', 'weights/classifier_BDTG.weights.xml')
scores = evaluate_reader(reader, 'BDT', X_test)
print("class probabilities:")
print(scores)
print("class probabilties should sum to 1:")
print(np.sum(scores,axis=1))
class_proba = evaluate_reader(reader, 'BDT', X_test)

# Plot the decision boundaries
plot_colors = "rgb"
plot_step = 0.02
class_names = "ABC"
cmap = plt.get_cmap('Paired')

plt.figure(figsize=(5, 5))
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
np.arange(y_min, y_max, plot_step))

Z = evaluate_reader(reader, 'BDT', np.c_[xx.ravel(), yy.ravel()])
Z = np.argmax(Z, axis=1) - 1
Z = Z.reshape(xx.shape)
plt.contourf(xx, yy, Z, cmap=cmap, vmin=Z.min(), vmax=Z.max(),
levels=np.linspace(Z.min(), Z.max(), 50))
plt.axis("tight")

# Plot the training points
for i, n, c in zip(range(3), class_names, plot_colors):
idx = np.where(y == i)
plt.scatter(X[idx, 0], X[idx, 1],
c=c, cmap=cmap,
label="Class %s" % n)
plt.xlim(x_min, x_max)
plt.ylim(y_min, y_max)
plt.legend(loc='upper right')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Decision Boundary')

plt.tight_layout()
plt.show()
4 changes: 3 additions & 1 deletion examples/tmva/plot_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
from ROOT import TMVA, TFile, TCut
from array import array

# Create an example regression dataset
plt.style.use('ggplot')
RNG = np.random.RandomState(1)

# Create an example regression dataset
X = np.linspace(0, 6, 100)[:, np.newaxis]
y = np.sin(X).ravel() + \
np.sin(6 * X).ravel() + \
Expand Down
3 changes: 2 additions & 1 deletion examples/tmva/plot_twoclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from root_numpy.tmva import add_classification_events, evaluate_reader
from ROOT import TMVA, TFile, TCut

plt.style.use('ggplot')
RNG = RandomState(42)

# Construct an example dataset for binary classification
Expand Down Expand Up @@ -100,7 +101,7 @@
range=plot_range,
facecolor=c,
label='Class %s' % n,
alpha=.5)
alpha=.5, histtype='stepfilled')
x1, x2, y1, y2 = plt.axis()
plt.axis((x1, x2, y1, y2 * 1.2))
plt.legend(loc='upper right')
Expand Down

0 comments on commit 4f858e0

Please sign in to comment.