
College project "Projet IAD" master 1
git clone https://esimon.eu/repos/mi024.git
Log | Files | Refs | README

nmlp_base_iostream.ipp (5035B)

      1 /**
      2  * @file
      3  * @brief nmlp base class iostream function definitions.
      4  */
      6 #include <cstddef>
      8 #include <boost/serialization/shared_ptr.hpp>
      9 #include <nmlp/CPUMatrix.h>
     10 #include <nmlp/CPUSparseMatrix.h>
     11 #include <nmlp/GPUMatrix.h>
     12 #include <nmlp/Matrix.h>
     13 #include <nmlp/Tensor.h>
     15 #include "nmlp_base_iostream.hpp"
     17 template<class Archive, class Matrix_type>
     18 void save_dense_matrix(Archive &ar, Matrix_type const &rhs, unsigned const){
     19 	Matrix_type &nonconst_rhs=const_cast<Matrix_type&>(rhs); // The interface of nmlp is const-inconsistent.
     20 	for(std::size_t y=0; y<nonconst_rhs.getNumberOfRows(); ++y)
     21 		for(std::size_t x=0; x<nonconst_rhs.getNumberOfColumns(); ++x){
     22 			float val=nonconst_rhs.getValue(y, x);
     23 			ar << val;
     24 		}
     25 }
     27 template<class Archive, class Matrix_type>
     28 void load_dense_matrix(Archive &ar, Matrix_type &rhs, unsigned const){
     29 	for(std::size_t y=0; y<rhs.getNumberOfRows(); ++y)
     30 		for(std::size_t x=0; x<rhs.getNumberOfColumns(); ++x){
     31 			float value;
     32 			ar >> value;
     33 			rhs.setValue(y, x, value);
     34 		}
     35 }
     37 template<class Archive, class Matrix_type>
     38 void load_construct_matrix_data(Archive &ar, Matrix_type *rhs, unsigned const){
     39 	int r, c;
     40 	ar >> r >> c;
     41 	new(rhs) Matrix_type(r, c);
     42 }
     44 template<class Archive, class Matrix_type>
     45 void save_construct_matrix_data(Archive &ar, Matrix_type const *rhs, unsigned const){
     46 	Matrix_type *nonconst_rhs=const_cast<Matrix_type*>(rhs); // The interface of nmlp is const-inconsistent.
     47 	int row=nonconst_rhs->getNumberOfRows(), column=nonconst_rhs->getNumberOfColumns();
     48 	ar << row << column;
     49 }
     51 template<class Archive>
     52 void save_construct_data(Archive &ar, CPUMatrix const *rhs, unsigned const version)
     53 	{ save_construct_matrix_data(ar, rhs, version); }
     55 template<class Archive>
     56 void save(Archive &ar, CPUMatrix const &rhs, unsigned const version){
     57 	boost::serialization::void_cast_register<CPUMatrix, Matrix>();
     58 	save_dense_matrix(ar, rhs, version);
     59 }
     61 template<class Archive>
     62 void load_construct_data(Archive &ar, CPUMatrix *rhs, unsigned const version)
     63 	{ load_construct_matrix_data(ar, rhs, version); }
     65 template<class Archive>
     66 void load(Archive &ar, CPUMatrix &rhs, unsigned const version){
     67 	boost::serialization::void_cast_register<CPUMatrix, Matrix>();
     68 	load_dense_matrix(ar, rhs, version);
     69 }
     72 template<class Archive>
     73 void save_construct_data(Archive &ar, GPUMatrix const *rhs, unsigned const version)
     74 	{ save_construct_matrix_data(ar, rhs, version); }
     76 template<class Archive>
     77 void save(Archive &ar, GPUMatrix const &rhs, unsigned const version){
     78 	boost::serialization::void_cast_register<GPUMatrix, Matrix>();
     79 	save_dense_matrix(ar, rhs, version);
     80 }
     82 template<class Archive>
     83 void load_construct_data(Archive &ar, GPUMatrix *rhs, unsigned const version)
     84 	{ load_construct_matrix_data(ar, rhs, version); }
     86 template<class Archive>
     87 void load(Archive &ar, GPUMatrix &rhs, unsigned const version){
     88 	boost::serialization::void_cast_register<GPUMatrix, Matrix>();
     89 	load_dense_matrix(ar, rhs, version);
     90 }
     92 template<class Archive>
     93 void save_construct_data(Archive &ar, CPUSparseMatrix const *rhs, unsigned const version)
     94 	{ save_construct_matrix_data(ar, rhs, version); }
     96 template<class Archive>
     97 void save(Archive &ar, CPUSparseMatrix const &rhs, unsigned const){
     98 	boost::serialization::void_cast_register<CPUSparseMatrix, Matrix>();
     99 	CPUSparseMatrix &nonconst_rhs=const_cast<CPUSparseMatrix&>(rhs); // The interface of nmlp is const-inconsistent.
    100 	int row, col, end_of_data=-1;
    101 	float val;
    102 	// This is the only way to do it with nmlp...
    103 	for(nonconst_rhs.initIterator(), nonconst_rhs.nextIterator(&row, &col, &val); nonconst_rhs.hasNextIterator(); nonconst_rhs.nextIterator(&row, &col, &val))
    104 		if(val)
    105 			ar << row << col << val;
    106 	if(val)
    107 		ar << row << col << val;
    108 	ar << end_of_data;
    109 }
    111 template<class Archive>
    112 void load_construct_data(Archive &ar, CPUSparseMatrix *rhs, unsigned const version)
    113 	{ load_construct_matrix_data(ar, rhs, version); }
    115 template<class Archive>
    116 void load(Archive &ar, CPUSparseMatrix &rhs, unsigned const){
    117 	boost::serialization::void_cast_register<CPUSparseMatrix, Matrix>();
    118 	int row, col;
    119 	float val;
    120 	while(ar >> row, row!=-1){
    121 		ar >> col >> val;
    122 		rhs.setValue(row, col, val);
    123 	}
    124 }
    126 template<class Archive>
    127 void save_construct_data(Archive &ar, Tensor const *rhs, unsigned const version){
    128 	int size=const_cast<Tensor*>(rhs)->getNumberOfMatrices();
    129 	ar << size;
    130 }
    132 template<class Archive>
    133 void save(Archive &ar, Tensor const &rhs, unsigned const){
    134 	Tensor &nonconst_rhs=const_cast<Tensor&>(rhs); // The interface of nmlp is const-inconsistent.
    135 	for(std::size_t i=0; i<nonconst_rhs.getNumberOfMatrices(); ++i){
    136 		boost::shared_ptr<Matrix> matrix=nonconst_rhs.getMatrix(i);
    137 		ar << matrix;
    138 	}
    139 }
    141 template<class Archive>
    142 void load_construct_data(Archive &ar, Tensor *rhs, unsigned const version){
    143 	int n;
    144 	ar >> n;
    145 	new(rhs) Tensor(n);
    146 }
    148 template<class Archive>
    149 void load(Archive &ar, Tensor &rhs, unsigned const){
    150 	for(std::size_t i=0; i<rhs.getNumberOfMatrices(); ++i){
    151 		boost::shared_ptr<Matrix> m;
    152 		ar >> m;
    153 		rhs.setMatrix(i, m);
    154 	}
    155 }