Blog Archive

Monday, March 19, 2018

kaldi中的Vector和Matrix

Link:

http://blog.csdn.net/u013677156/article/details/79202271



kaldi中的Vector和Matrix
    Vector和Matrix是kaldi中最常用的数据类型之一。语音数据,提取的特征,计算的结果,都保存在Vector或者Matrix之中。按照字面意思,Vector是“向量”,它只有一行数据,是一维的。Matrix是“矩阵”,它有行与列两个维度。kaldi中的Vector和Matrix,可以做许多数学上的操作。比如点加或点乘(每个元素都加上一个数,或者乘以一个数),比如矩阵之间的乘法和矩阵的奇异分解等。kaldi中Vector和Matrix还可以做一些特殊操作,比如对每个元素取对数,对所有元素做softmax等。
一、首先介绍下Vector。
    在matrix/kaldi-vector.h中,定义了三个类:VectorBase、Vector和SubVector。其中,VectorBase是基类(父类),Vector和SubVector是派生类(子类)。VectorBase中的成员函数已经可以完成一个向量类的所有操作了,Vector类只是做了封装,定义了多种形式的构造函数,增加了resize操作等。
    VectorBase类中的数据成员十分简单,就两个成员。一个指针data_指向存放数据的内存,一个整数dim_指示元素的个数。
VectorBase中的函数成员比较多,但基本可以分为两类。一类是基本的、简单的操作。例如SetZero函数,用以设置全部数据为0;例如max函数,返回向量中的最大值。另一类是偏应用的函数或操作。比如,ApplySoftMax函数,提供softmax操作;比如Norm函数,计算范数。

    下面的代码,有利于理解VectorBase的各种性质。注意,为了方面阅读和理解,对源代码做了修改。
  1. template<typename Real>  
  2. class VectorBase {  
  3.   // *******数据成员,data_表示内存地址,dim_表示元素个数*******  
  4.   Real* data_;  
  5.   MatrixIndexT dim_;  
  6.   explicit VectorBase(): data_(NULL), dim_(0) { }  
  7.     
  8.   // **************第一类,比较基础和简单的函数*******************  
  9.   // set类函数,设置全部值为0、特定值或者某种分布的随机值  
  10.   void SetZero();  
  11.   void Set(Real f);  
  12.   void SetRandn();  
  13.   void SetRandUniform();  
  14.   
  15.   // 返回元素个数dim_,返回元素地址data_,返回占用内存大小,重载()操作等  
  16.   inline MatrixIndexT Dim() const { return dim_; }  
  17.   inline Real* Data() { return data_; }  
  18.   inline MatrixIndexT SizeInBytes() const { return (dim_*sizeof(Real)); }  
  19.   inline Real operator() (MatrixIndexT i) const { return *(data_ + i);  }  
  20.   inline Real & operator() (MatrixIndexT i) {   return *(data_ + i);  }  
  21.   Real Max() const;  
  22.   Real Min() const;  
  23.   Real Sum() const;  
  24.   Real SumLog() const;  
  25.     
  26.   // 拷贝向量或者矩阵(全部或者局部,例如一行)的内容,来作为data_  
  27.   void CopyFromVec(const VectorBase<Real> &v);  
  28.   void CopyFromPacked(const PackedMatrix<Real> &M);  
  29.   void CopyRowsFromMat(const MatrixBase<Real> &M);  
  30.   void CopyColsFromMat(const MatrixBase<Real> &M);  
  31.   void CopyRowFromMat(const MatrixBase<Real> &M, MatrixIndexT row);  
  32.   void CopyDiagFromMat(const MatrixBase<Real> &M);  
  33.     
  34.   // **************第二类,偏应用的操作和函数*******************  
  35.   void Add(Real c);      /// data_[i] += c;  
  36.   void Scale(Real c);    /// data_[i] *= c; cblas_Xscal(dim_, c, data_, 1);  
  37.   void ApplyLog();       /// data_[i] = Log(data_[i])  
  38.   void ApplyExp();       /// data_[i] = Exp(data_[i])  
  39.   void ApplyAbs();       /// data_[i] = abs(data_[i])  
  40.   void InvertElements(); /// data_[i] = 1 / data_[i]  
  41.   void ApplyPow(Real power);  // 求指数  
  42.   Real Norm(Real p) const;    // 求p阶范数  
  43.   void MulElements(const VectorBase<Real> &v); //data_[i] *= v.data_[i];  
  44.   void DivElements(const VectorBase<Real> &v); //data_[i] /= v.data_[i];  
  45.   
  46.   //各种形式的矩阵操作,一般调用BLAS,例如 AddVec: *this = *this + alpha * rv   
  47.   void AddVec(const Real alpha, const VectorBase<Real> &v);  
  48.   void AddVec2(const Real alpha, const VectorBase<Real> &v); // rv^2  
  49.   void AddMatVec(...);  //  this <-- beta*this + alpha*M*v.  
  50.   void AddSpVec(...)    //  this <-- beta*this + alpha*M*v.  
  51.   void AddTpVec(...)    //  this <-- beta*this + alpha*M*v.  
  52.   void AddVecVec(...);  //  this <-- alpha * v .* r + beta*this .  
  53.   void AddVecDivVec(...);// this <---- alpha*v/r + beta*this  
  54.   void MulTp(...);      //  *this <-- *this *M  
  55.     
  56.   //使用softmax: \f$ x(i) = exp(x(i)) / \sum_i exp(x(i)) \f$  
  57.   Real ApplySoftMax(){  
  58.     Real max = this->Max(), sum = 0.0;  
  59.     for (MatrixIndexT i = 0; i < dim_; i++)  
  60.       sum += (data_[i] = Exp(data_[i] - max));  
  61.     this->Scale(1.0 / sum);  
  62.     return max + Log(sum);  
  63.   }  
  64.   void Tanh(const VectorBase<Real> &src);  
  65.   void Sigmoid(const VectorBase<Real> &src);  
  66. }; // class VectorBase  
  67.   
  68.   
  69. template<typename Real>  
  70. class Vector: public VectorBase<Real> {  
  71.  public:  
  72.   // 各种构造函数和赋值操作。  
  73.   Vector(): VectorBase<Real>() {}  
  74.   explicit Vector(const MatrixIndexT s, MatrixResizeType resize_type)  
  75.       : VectorBase<Real>() {  Resize(s, resize_type);  }  
  76.   Vector(const Vector<Real> &v) : VectorBase<Real>()  {   
  77.     Resize(v.Dim(), kUndefined);  
  78.     this->CopyFromVec(v);  }  
  79.   explicit Vector(const VectorBase<Real> &v) : VectorBase<Real>() {  
  80.     Resize(v.Dim(), kUndefined);  
  81.     this->CopyFromVec(v);  }  
  82.   Vector<Real> &operator = (const Vector<Real> &other) {  
  83.     Resize(other.Dim(), kUndefined);  
  84.     this->CopyFromVec(other);  
  85.     return *this; }  
  86.   
  87.   // 新增的Swap、Resize和RemoveElement操作  
  88.   void Swap(Vector<Real> *other);  
  89.   void Resize(MatrixIndexT length, MatrixResizeType resize_type = kSetZero);  
  90.   void RemoveElement(MatrixIndexT i);  
  91.   
  92.  private:  
  93.   void Init(const MatrixIndexT dim);  
  94.   void Destroy();  
  95. };  
  96.   
  97. template<typename Real>  
  98. class SubVector : public VectorBase<Real> {  
  99.  public:  
  100.   //SubVector不分配内存,它使用其他VectorBase的数据,可以看作是“引用”。  
  101.   // 下面是各种版本的构造函数。  
  102.   SubVector(const VectorBase<Real> &t, const MatrixIndexT origin,  
  103.             const MatrixIndexT length) : VectorBase<Real>() {  
  104.      VectorBase<Real>::data_ = const_cast<Real*> (t.Data()+origin);  
  105.     VectorBase<Real>::dim_   = length;  
  106.   }  
  107.   SubVector(const PackedMatrix<Real> &M) {  
  108.     VectorBase<Real>::data_ = const_cast<Real*> (M.Data());  
  109.     VectorBase<Real>::dim_   = (M.NumRows()*(M.NumRows()+1))/2;  
  110.   }  
  111.   SubVector(const SubVector &other) : VectorBase<Real> () {// Copy constructor  
  112.     VectorBase<Real>::data_ = other.data_;  
  113.     VectorBase<Real>::dim_ = other.dim_;  
  114.   }  
  115.   SubVector(Real *data, MatrixIndexT length) : VectorBase<Real> () {  
  116.     VectorBase<Real>::data_ = data;  
  117.     VectorBase<Real>::dim_   = length;  
  118.   }  
  119.   SubVector(const MatrixBase<Real> &matrix, MatrixIndexT row) {  
  120.     VectorBase<Real>::data_ = const_cast<Real*>(matrix.RowData(row));  
  121.     VectorBase<Real>::dim_   = matrix.NumCols();  
  122.   }  
  123.   ~SubVector() {}  ///< Destructor (does nothing; no pointers are owned here).  
  124.   
  125.  private:  
  126.   /// Disallow assignment operator.  
  127.   SubVector & operator = (const SubVector &other) {}  
  128. };  

 通过上面的代码,我们可以看出,Vector对VectorBase并未做太多的扩展,它们的功能基本一样。SubVector可以看作一种“引用”,它自身并不分配内存保存数据,而是指向了其他的对象中的数据。


二、简单介绍下Matrix。
    跟Vector类似,在在matrix/kaldi-matrix.h中,定义了三个类:MatrixBase、Matrix和SubMatrix。MatrixBase是基类,另外两个是派生类。MatrixBase已经实现了非常多的方法。Matrix只是在基类的基础上,加了少数几个函数,比如Swap和RemoveRow等,这点跟Vector与VectorBase的关系一样。
    MatrixBase中,数据成员并不多,大部分也容易理解。比如,整数num_rows_和num_cols_表示矩阵的行数和列数,指针data_指向保存数据的内存地址。这里有另外一个整型变量stride_需要注意。stride_保存的是正真的一行的个数。这里的意思是,一个矩阵,一行可能可以存放许多数据(stride_个),但可以不放满,只放num_cols_个。这时,一部分空间是浪费的。当然,一般部分情况下,num_cols_和stride_是一致的。
    在矩阵上面的操作要比向量上的操作多,所以Matrix中的成员函数比Vector中的多很多。
  1. template<typename Real>  
  2. class MatrixBase {  
  3.   //***************数据成员********************  
  4.   Real*   data_;             // data memory area  
  5.   MatrixIndexT  num_cols_;   // < Number of columns  
  6.   MatrixIndexT  num_rows_;   // < Number of rows  
  7.   MatrixIndexT  stride_;     // True number of columns   
  8.     
  9.   // 基本操作函数  
  10.   inline MatrixIndexT NumRows() const { return num_rows_; }  
  11.   inline MatrixIndexT NumCols() const { return num_cols_; }  
  12.   inline MatrixIndexT Stride() const {  return stride_; }  
  13.   inline Real* Data() const { return data_;  }  
  14.   inline Real* RowData(MatrixIndexT i) { return data_ + i * stride_;  }  
  15.   inline Real&  operator() ( r,  c) {return *(data_ + r * stride_ + c);  }  
  16.   size_t SizeInBytes() const {return num_rows_ * stride_ * sizeof(Real);}  
  17.   Real &Index (MatrixIndexT r, MatrixIndexT c) {  return (*this)(r, c); }  
  18.   
  19.   // set、max、min等函数,省略若干  
  20.   void SetZero();  
  21.   void Set(Real);  
  22.   Real Sum() const;  
  23.   Real Max() const;  
  24.   Real Min() const;  
  25.   bool IsZero(Real cutoff = 1.0e-05) const;  
  26.   
  27.   //Copy、SubVector、SubMatrix类函数,很多版本  
  28.   void CopyFromMat(const CompressedMatrix &M);  
  29.   void CopyRowsFromVec(const VectorBase<Real> &v);  
  30.   void CopyDiagFromVec(const VectorBase<Real> &v);  
  31.   inline SubVector<Real> Row(MatrixIndexT i);  
  32.   inline SubMatrix<Real> Range(...);  
  33.   
  34.   // 一些加减乘除操作,其他应用操作  
  35.   void MulElements(const MatrixBase<Real> &A);  
  36.   void DivElements(const MatrixBase<Real> &A);  
  37.   void Scale(Real alpha);  
  38.   void Max(const MatrixBase<Real> &A);  
  39.   void Min(const MatrixBase<Real> &A);  
  40.   void MulColsVec(const VectorBase<Real> &scale);  
  41.   void MulRowsVec(const VectorBase<Real> &scale);  
  42.   void Add(const Real alpha);  
  43.   void ApplyFloor(Real floor_val);  
  44.   void ApplyCeiling(Real ceiling_val);  
  45.   void ApplyLog();  
  46.   void ApplyExp();  
  47.   Real ApplySoftMax();  
  48.   void Sigmoid(const MatrixBase<Real> &src);  
  49.   
  50.   // 求正定矩阵、求逆;转置;特征分解;奇异值分解;矩阵运算  
  51.   Real LogDet(Real *det_sign = NULL) const;  
  52.   void Invert(Real *log_det = NULL, Real *det_sign = NULL,  
  53.               bool inverse_needed = true);  
  54.   void Transpose();  
  55.   void Eig(MatrixBase<Real> *P,  
  56.            VectorBase<Real> *eigs_real,  
  57.            VectorBase<Real> *eigs_imag) const;  
  58.   void Svd(VectorBase<Real> *s, MatrixBase<Real> *U,  
  59.            MatrixBase<Real> *Vt) const;  
  60.   void AddVecVec(...) //*this += alpha * a * b^T  
  61.   void AddMat(...) //*this += alpha * M  
  62.   void AddMatMatMat(...) //this <-- beta*this + alpha*A*B*C.  
  63.   void AddTpTp(...) //this <-- beta*this + alpha*A*B.  
  64. };  


'via Blog this'

No comments:

Post a Comment