-
Notifications
You must be signed in to change notification settings - Fork 19
/
Program.cs
124 lines (109 loc) · 4.69 KB
/
Program.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
using System;
using System.IO;
using Microsoft.ML;
// Attempts to predict a stock's daily closing price (AAPL sample data) from that day's
// open, high, low and volume, using an ML.NET FastTree regression model.
namespace StockPricePrediction
{
    class Program
    {
        // Training data CSV, resolved relative to the process working directory.
        private static readonly string _trainDataPath = Path.Combine(Environment.CurrentDirectory, "Data", "AAPL-Train.csv");

        // Testing data CSV, resolved relative to the process working directory.
        private static readonly string _testDataPath = Path.Combine(Environment.CurrentDirectory, "Data", "AAPL-Test.csv");

        /// <summary>
        /// Entry point: trains a model on the training CSV, reports its quality on the
        /// test CSV, then runs one hand-written sample row through it.
        /// </summary>
        /// <param name="args">Command-line arguments (unused).</param>
        private static void Main(string[] args)
        {
            // Printed so it is easy to see which directory the relative Data paths resolve against.
            Console.WriteLine(Environment.CurrentDirectory);

            // A fixed seed makes ML.NET's randomized operations repeatable across runs.
            MLContext mlContext = new MLContext(seed: 0);

            var model = Train(mlContext, _trainDataPath);
            Evaluate(mlContext, model);
            TestSinglePrediction(mlContext, model);
        }

        /// <summary>
        /// Loads the training CSV and fits a FastTree regression pipeline that predicts the
        /// closing price from the Open, High, Low and Volume columns.
        /// </summary>
        /// <param name="mlContext">ML.NET context used to load data and build the pipeline.</param>
        /// <param name="dataPath">Path to a comma-separated file with a header row, read as <c>StockPrice</c> rows.</param>
        /// <returns>The fitted pipeline (column copy, feature concatenation and trained regressor).</returns>
        public static ITransformer Train(MLContext mlContext, string dataPath)
        {
            IDataView dataView = mlContext.Data.LoadFromTextFile<StockPrice>(dataPath, hasHeader: true, separatorChar: ',');

            // The regression trainer reads the "Label" and "Features" columns by default, so copy
            // the target (Close) into Label and pack the inputs into a single Features vector.
            var pipeline = mlContext.Transforms.CopyColumns(outputColumnName: "Label", inputColumnName: "Close")
                .Append(mlContext.Transforms.Concatenate("Features", "Open", "High", "Low", "Volume"))
                .Append(mlContext.Regression.Trainers.FastTree());

            Console.WriteLine("=============== Create and Train the Model ===============");
            var model = pipeline.Fit(dataView);
            Console.WriteLine("=============== End of training ===============");
            Console.WriteLine();

            return model;
        }

        /// <summary>
        /// Scores the test CSV with <paramref name="model"/> and prints the R-squared and
        /// root-mean-squared-error regression metrics to the console.
        /// </summary>
        /// <param name="mlContext">ML.NET context used to load data and compute metrics.</param>
        /// <param name="model">Model returned by <see cref="Train"/>.</param>
        private static void Evaluate(MLContext mlContext, ITransformer model)
        {
            IDataView dataView = mlContext.Data.LoadFromTextFile<StockPrice>(_testDataPath, hasHeader: true, separatorChar: ',');
            var predictions = model.Transform(dataView);

            // "Label" holds the actual close (copied by the pipeline); "Score" is the regressor's output.
            var metrics = mlContext.Regression.Evaluate(predictions, "Label", "Score");

            Console.WriteLine();
            Console.WriteLine($"*************************************************");
            Console.WriteLine($"* Model quality metrics evaluation ");
            Console.WriteLine($"*------------------------------------------------");
            Console.WriteLine($"* RSquared Score: {metrics.RSquared:0.##}");
            // "0.##" rather than "#.##": the '#' placeholder drops the leading zero, so an RMSE
            // below 1 would print as ".42" and an RMSE of exactly 0 would print as nothing.
            Console.WriteLine($"* Root Mean Squared Error: {metrics.RootMeanSquaredError:0.##}");
            Console.WriteLine($"*************************************************");
        }

        /// <summary>
        /// Runs a single hand-written row (taken from the 2019-11-01 data) through the model and
        /// prints the predicted closing price next to the observed one.
        /// </summary>
        /// <param name="mlContext">ML.NET context used to create the prediction engine.</param>
        /// <param name="model">Model returned by <see cref="Train"/>.</param>
        private static void TestSinglePrediction(MLContext mlContext, ITransformer model)
        {
            // PredictionEngine implements IDisposable; release it once the prediction is printed.
            using (var predictionFunction = mlContext.Model.CreatePredictionEngine<StockPrice, StockPriceClosePrediction>(model))
            {
                // Sample row from the source data:
                //   Date,Open,High,Low,Close,Volume
                //   20191101,249.56,255.93,249.16,255.82,29671000
                var stockPriceSample = new StockPrice()
                {
                    Open = 249.56f,
                    High = 255.93f,
                    Low = 249.16f,
                    // Placeholder: in the pipeline built by Train, Close only feeds the Label
                    // column, not Features. Actual/Observed = 255.82
                    Close = 0,
                    Volume = 29671000
                };

                var prediction = predictionFunction.Predict(stockPriceSample);

                Console.WriteLine($"**********************************************************************");
                Console.WriteLine($"Predicted Close Price: {prediction.Close:0.####}, actual close: 255.82");
                Console.WriteLine($"**********************************************************************");
            }
        }
    }
}