
Commit a5849e2

update comments
1 parent 9ea1a87 commit a5849e2

10 files changed (+191 additions, -38 deletions)

Doxyfile

Lines changed: 1 addition & 1 deletion
@@ -441,7 +441,7 @@ EXTRACT_ALL = NO
# be included in the documentation.
# The default value is: NO.

-EXTRACT_PRIVATE = YES
+EXTRACT_PRIVATE = NO

# If the EXTRACT_PACKAGE tag is set to YES, all members with package or internal
# scope will be included in the documentation.

include/thundersvm/cmdparser.h

Lines changed: 13 additions & 7 deletions
@@ -10,19 +10,25 @@

#include "svmparam.h"

+/**
+ * @brief Command-line parser
+ */
class CMDParser{
public:
-   CMDParser() : do_cross_validation(false), nr_fold(0), gpu_id(0) {};
+   CMDParser() : do_cross_validation(false), nr_fold(0), gpu_id(0) {};
+
    void parse_command_line(int argc, char **argv);
-   void parse_python(int argc, char **argv);
-   SvmParam param_cmd;
+
+   void parse_python(int argc, char **argv);
+
+   SvmParam param_cmd;
    bool do_cross_validation;
    int nr_fold;
-   int gpu_id;
+   int gpu_id;
    char svmtrain_input_file_name[1024];
-   char svmpredict_input_file[1024];
-   char svmpredict_output_file[1024];
-   char svmpredict_model_file_name[1024];
+   char svmpredict_input_file[1024];
+   char svmpredict_output_file[1024];
+   char svmpredict_model_file_name[1024];
    char model_file_name[1024];
};

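
The comment added above documents CMDParser as the command-line parser that fills an SvmParam and the input/output file names. A minimal usage sketch under that reading; the include path and the main() wrapper are illustrative assumptions, not part of this commit:

#include <thundersvm/cmdparser.h>

int main(int argc, char **argv) {
    CMDParser parser;
    parser.parse_command_line(argc, argv);   // fills param_cmd and the file-name buffers
    if (parser.do_cross_validation) {
        // parser.nr_fold holds the requested number of folds
    }
    SvmParam param = parser.param_cmd;       // parsed training parameters
    (void) param;
    return 0;
}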

include/thundersvm/dataset.h

Lines changed: 34 additions & 0 deletions
@@ -7,37 +7,71 @@

#include "thundersvm.h"
#include "syncdata.h"
+
+/**
+ * @brief Dataset reader
+ */
class DataSet {
public:
    struct node{
        node(int index, float_type value) : index(index), value(value) {}
+
        int index;
        float_type value;
    };
+
    typedef vector<vector<DataSet::node>> node2d;

    DataSet();

+   /**
+    * construct a dataset from given instances
+    * @param instances the given instances
+    * @param n_features the number of features of the given instances
+    * @param y the label of each instance
+    */
    DataSet(const DataSet::node2d &instances, int n_features, const vector<float_type> &y);
+
+   ///load the dataset from a file
    void load_from_file(string file_name);
+
+   ///load the dataset from Python
    void load_from_python(float *y, char **x, int len);
+
+   ///group instances of the same class
    void group_classes(bool classification = true);
+
    size_t n_instances() const;
+
    size_t n_features() const;
+
    size_t n_classes() const;

+   ///the number of instances in each class
    const vector<int> &count() const;

+   ///the start position of the instances of each class
    const vector<int> &start() const;

+   ///mapping from logical labels (0,1,2,3,...) to real labels (maybe 2,4,5,6,...)
    const vector<int> &label() const;

+   ///the label of each instance, in the order the instances appear in the file
    const vector<float_type> &y() const;
+
    const node2d & instances() const;
+
+   ///instances of class \f$y_i\f$
    const node2d instances(int y_i) const;
+
+   ///instances of classes \f$y_i\f$ and \f$y_j\f$
    const node2d instances(int y_i, int y_j) const;
+
+   ///mapping from instance indices (after grouping) to the original indices (in the file)
    const vector<int> original_index() const;
+
    const vector<int> original_index(int y_i) const;
    const vector<int> original_index(int y_i, int y_j) const;

private:
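
The grouping and label-mapping accessors documented above (count(), start(), label()) work together after group_classes(). A short sketch of how they fit; the include path and the input file format are illustrative assumptions:

#include <thundersvm/dataset.h>
#include <string>

void inspect_dataset(const std::string &path) {
    DataSet dataset;
    dataset.load_from_file(path);      // read instances and labels from the file
    dataset.group_classes();           // reorder instances so each class is contiguous
    for (size_t c = 0; c < dataset.n_classes(); ++c) {
        int real_label = dataset.label()[c];   // real label of logical class c
        int first = dataset.start()[c];        // offset of the first instance of class c
        int num = dataset.count()[c];          // number of instances of class c
        (void) real_label; (void) first; (void) num;
    }
}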

include/thundersvm/kernelmatrix.h

Lines changed: 26 additions & 2 deletions
@@ -10,18 +10,42 @@
#include "dataset.h"
#include "svmparam.h"

+/**
+ * @brief The management class of kernel values.
+ */
class KernelMatrix{
public:
+   /**
+    * Create a KernelMatrix with the given instances (training data or support vectors).
+    * @param instances the instances: training instances during training, or support vectors during prediction
+    * @param param only kernel_type in param is used
+    */
    explicit KernelMatrix(const DataSet::node2d &instances, SvmParam param);

+   /**
+    * return specific rows of the kernel matrix
+    * @param [in] idx the indices of the rows
+    * @param [out] kernel_rows
+    */
    void get_rows(const SyncData<int> &idx, SyncData<float_type> &kernel_rows) const;

+   /**
+    * return the kernel values between each given instance and each instance stored in the KernelMatrix
+    * @param [in] instances the given instances
+    * @param [out] kernel_rows
+    */
    void get_rows(const DataSet::node2d &instances, SyncData<float_type> &kernel_rows) const;

+   ///return the diagonal elements of the kernel matrix
    const SyncData<float_type> &diag() const;

-   size_t n_instances() const { return n_instances_; };//number of instances
-   size_t n_features() const { return n_features_; };//number of features
+   ///the number of instances in the KernelMatrix
+   size_t n_instances() const { return n_instances_; };
+
+   ///the maximum number of features of the instances
+   size_t n_features() const { return n_features_; }
+
+   ///the number of non-zero features over all instances
    size_t nnz() const {return nnz_;};//number of nonzero
private:
    KernelMatrix &operator=(const KernelMatrix &) const;
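
A sketch of the get_rows() flow documented above, fetching a couple of kernel rows against the stored instances; the include paths, the helper name, and the assumption that each returned row holds n_instances() values are illustrative:

#include <thundersvm/dataset.h>
#include <thundersvm/kernelmatrix.h>
#include <thundersvm/syncdata.h>

void kernel_rows_demo(const DataSet &dataset, const SvmParam &param) {
    KernelMatrix k_mat(dataset.instances(), param);        // param.kernel_type selects the kernel
    SyncData<int> idx(2);
    idx[0] = 0;                                             // row indices to fetch
    idx[1] = 1;
    SyncData<float_type> rows(idx.size() * k_mat.n_instances());
    k_mat.get_rows(idx, rows);                              // kernel values of instances 0 and 1 against all instances
    const SyncData<float_type> &d = k_mat.diag();           // diagonal entries of the kernel matrix
    (void) d;
}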

include/thundersvm/model/nusvr.h

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@
#include "svr.h"

/**
- * @brief-Support Vector Machine for regression
+ * @brief Support Vector Machine for regression
 */
class NuSVR : public SVR {
public:

include/thundersvm/model/svr.h

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
using std::map;

/**
- * Support Vector Machine for regression
+ * @brief Support Vector Machine for regression
 */
class SVR : public SvmModel {
public:

include/thundersvm/svmparam.h

Lines changed: 30 additions & 19 deletions
@@ -7,6 +7,9 @@

#include "thundersvm.h"

+/**
+ * @brief params for ThunderSVM
+ */
struct SvmParam {
    SvmParam() {
        svm_type = C_SVC;
@@ -20,32 +23,40 @@ struct SvmParam {
        nr_weight = 0;
    }

+   /// SVM type
    enum SVM_TYPE {
        C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR
-   }; /* svm_type */
+   };
+   /// kernel function type
    enum KERNEL_TYPE {
        LINEAR, POLY, RBF, SIGMOID/*, PRECOMPUTED*/
-   }; /* kernel_type */
+   };
    SVM_TYPE svm_type;
    KERNEL_TYPE kernel_type;

-   float_type C; //for regularization
-   float_type gamma; //for rbf kernel
-   float_type p; //for regression
-   float_type nu; //for nu-SVM
-   float_type epsilon; //stopping criteria
-   int degree; //degree for polynomial kernel
-
-   float_type coef0; /* for poly/sigmoid */
-
-   /* these are for training only */
-   // double cache_size; /* in MB */
-   int nr_weight; /* for C_SVC */
-   int *weight_label; /* for C_SVC */
-   float_type *weight; /* for C_SVC */
-   // int shrinking; /* use the shrinking heuristics */
-   int probability; /* do probability estimates */
+   ///regularization parameter
+   float_type C;
+   ///for RBF kernel
+   float_type gamma;
+   ///for regression
+   float_type p;
+   ///for \f$\nu\f$-SVM
+   float_type nu;
+   ///stopping criteria
+   float_type epsilon;
+   ///degree for polynomial kernel
+   int degree;
+   ///for polynomial/sigmoid kernel
+   float_type coef0;
+   ///for SVC
+   int nr_weight;
+   ///for SVC
+   int *weight_label;
+   ///for SVC
+   float_type *weight;
+   ///do probability estimates
+   int probability;
    static const char *kernel_type_name[6];
-   static const char *svm_type_name[6]; /* svm_type */
+   static const char *svm_type_name[6];
};
#endif //THUNDERSVM_SVMPARAM_H
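
The constructor above fills in defaults, and each documented field can then be overridden. A sketch of configuring an RBF-kernel C-SVC; the include path and the concrete values are illustrative assumptions:

#include <thundersvm/svmparam.h>

SvmParam make_param() {
    SvmParam param;                       // constructor fills in the defaults
    param.svm_type = SvmParam::C_SVC;
    param.kernel_type = SvmParam::RBF;
    param.C = 10;                         // regularization parameter
    param.gamma = 0.5;                    // RBF kernel parameter
    param.probability = 0;                // no probability estimates
    return param;
}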

include/thundersvm/syncdata.h

Lines changed: 41 additions & 7 deletions
@@ -8,64 +8,98 @@
#include "thundersvm.h"
#include "syncmem.h"

+/**
+ * @brief Wrapper of SyncMem with a type
+ * @tparam T type of element
+ */
template<typename T>
class SyncData : public el::Loggable {
public:
+   /**
+    * initialize a SyncData that can store the given count of elements
+    * @param count the given count
+    */
    explicit SyncData(size_t count);

    SyncData() : mem(nullptr), size_(0) {};
+
    ~SyncData();

    const T *host_data() const;
+
    const T *device_data() const;

    T *host_data();
+
    T *device_data();

    void set_host_data(T *host_ptr){
        mem->set_host_data(host_ptr);
    }
+
    void set_device_data(T *device_ptr){
-       mem->set_device_data(device_ptr);
+       mem->set_device_data(device_ptr);
    }

    void to_host() const{
-       mem->to_host();
+       mem->to_host();
    }
+
    void to_device() const{
-       mem->to_device();
+       mem->to_device();
    }

+   /**
+    * random access operator
+    * @param index the index of the element
+    * @return the **host** element at the index
+    */
    const T &operator[](int index) const{
-       return host_data()[index];
+       return host_data()[index];
    }
+
    T &operator[](int index){
-       return host_data()[index];
+       return host_data()[index];
    }

+   /**
+    * copy device data. This will call to_device() implicitly.
+    * @param source source device data pointer
+    * @param count the count of elements
+    */
    void copy_from(const T *source, size_t count);
+
    void copy_from(const SyncData<T> &source);

+   /**
+    * set all elements to the given value. This method will set device data.
+    * @param value
+    */
    void mem_set(const T &value);

+   /**
+    * resize to a new size. This will also clear all data.
+    * @param count
+    */
    void resize(size_t count);

    size_t mem_size() const {//number of bytes
-       return mem->size();
+       return mem->size();
    }

    size_t size() const {//number of values
        return size_;
    }

    SyncMem::HEAD head() const{
-       return mem->head();
+       return mem->head();
    }

    void log(el::base::type::ostream_t &ostream) const override;

private:
    SyncData<T> &operator=(const SyncData<T> &);
+
    SyncData(const SyncData<T>&);

    SyncMem *mem;
0 commit comments