MLPACK  1.0.10
validation_RMSE_termination.hpp
Go to the documentation of this file.
1 
22 #ifndef VALIDATION_RMSE_TERMINATION_HPP_INCLUDED
23 #define VALIDATION_RMSE_TERMINATION_HPP_INCLUDED
24 
25 #include <mlpack/core.hpp>
26 
27 namespace mlpack
28 {
29 namespace amf
30 {
31 template <class MatType>
33 {
34  public:
36  size_t num_test_points,
37  double tolerance = 1e-5,
38  size_t maxIterations = 10000,
39  size_t reverseStepTolerance = 3)
42  num_test_points(num_test_points),
44  {
45  size_t n = V.n_rows;
46  size_t m = V.n_cols;
47 
48  test_points.zeros(num_test_points, 3);
49 
50  for(size_t i = 0; i < num_test_points; i++)
51  {
52  double t_val;
53  size_t t_row;
54  size_t t_col;
55  do
56  {
57  t_row = rand() % n;
58  t_col = rand() % m;
59  } while((t_val = V(t_row, t_col)) == 0);
60 
61  test_points(i, 0) = t_row;
62  test_points(i, 1) = t_col;
63  test_points(i, 2) = t_val;
64  V(t_row, t_col) = 0;
65  }
66  }
67 
68  void Initialize(const MatType& /* V */)
69  {
70  iteration = 1;
71 
72  rmse = DBL_MAX;
73  rmseOld = DBL_MAX;
74 
75  c_index = 0;
76  c_indexOld = 0;
77 
78  reverseStepCount = 0;
79  isCopy = false;
80  }
81 
82  bool IsConverged(arma::mat& W, arma::mat& H)
83  {
84  // Calculate norm of WH after each iteration.
85  arma::mat WH;
86 
87  WH = W * H;
88 
89  if (iteration != 0)
90  {
91  rmseOld = rmse;
92  rmse = 0;
93  for(size_t i = 0; i < num_test_points; i++)
94  {
95  size_t t_row = test_points(i, 0);
96  size_t t_col = test_points(i, 1);
97  double t_val = test_points(i, 2);
98  double temp = (t_val - WH(t_row, t_col));
99  temp *= temp;
100  rmse += temp;
101  }
103  rmse = sqrt(rmse);
104  }
105 
106  iteration++;
107 
108  if((rmseOld - rmse) / rmseOld < tolerance && iteration > 4)
109  {
110  if(reverseStepCount == 0 && isCopy == false)
111  {
112  isCopy = true;
113  this->W = W;
114  this->H = H;
116  c_index = rmse;
117  }
119  }
120  else
121  {
122  reverseStepCount = 0;
123  if(rmse <= c_indexOld && isCopy == true)
124  {
125  isCopy = false;
126  }
127  }
128 
130  {
131  if(isCopy)
132  {
133  W = this->W;
134  H = this->H;
135  rmse = c_index;
136  }
137  return true;
138  }
139  else return false;
140  }
141 
142  const double& Index() { return rmse; }
143 
144  const size_t& Iteration() { return iteration; }
145 
146  const size_t& MaxIterations() { return maxIterations; }
147 
148  private:
149  double tolerance;
152  size_t iteration;
153 
154  arma::Mat<double> test_points;
155 
156  double rmseOld;
157  double rmse;
158 
161 
162  bool isCopy;
163  arma::mat W;
164  arma::mat H;
165  double c_indexOld;
166  double c_index;
167 };
168 
169 } // namespace amf
170 } // namespace mlpack
171 
172 
173 #endif // VALIDATION_RMSE_TERMINATION_HPP_INCLUDED
174 
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: load.hpp:31
ValidationRMSETermination(MatType &V, size_t num_test_points, double tolerance=1e-5, size_t maxIterations=10000, size_t reverseStepTolerance=3)