File indexing completed on 2024-04-06 12:13:10
0001
0002
0003
0004
0005
0006
0007
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <sstream>
#include <string>
#include <tuple>
#include <typeindex>
#include <typeinfo>
#include <utility>
#include <vector>

#include <cppunit/extensions/HelperMacros.h>

#include "FWCore/SOA/interface/Table.h"
#include "FWCore/SOA/interface/TableView.h"
#include "FWCore/SOA/interface/RowView.h"
#include "FWCore/SOA/interface/Column.h"
#include "FWCore/SOA/interface/TableItr.h"
#include "FWCore/SOA/interface/TableExaminer.h"
0020
// CppUnit fixture for the edm::soa Table / TableView / RowView / TableItr /
// TableExaminer types. Each CPPUNIT_TEST entry registers one member function
// below as a test case; the entries are listed in registration order.
class testTable : public CppUnit::TestFixture {
  CPPUNIT_TEST_SUITE(testTable);

  CPPUNIT_TEST(rowviewCtrTest);
  CPPUNIT_TEST(rawTableItrTest);
  CPPUNIT_TEST(tableCtrTest);
  CPPUNIT_TEST(tableStandardOpsTest);
  CPPUNIT_TEST(tableColumnTest);
  CPPUNIT_TEST(tableViewConversionTest);
  CPPUNIT_TEST(tableExaminerTest);
  CPPUNIT_TEST(tableResizeTest);
  CPPUNIT_TEST(mutabilityTest);
  CPPUNIT_TEST_SUITE_END();

public:
  // No shared fixture state: every test builds the tables it needs locally.
  void setUp() {}
  void tearDown() {}

  // RowView built from an array of pointers to individual values.
  void rowviewCtrTest();
  // TableItr walking raw per-column C arrays.
  void rawTableItrTest();
  // Table construction from per-column containers and from object containers.
  void tableCtrTest();
  // Copy/move construction and copy/move assignment of Table.
  void tableStandardOpsTest();
  // Iteration over a single column via Table::column<C>().
  void tableColumnTest();
  // Conversion of a Table into TableViews over a subset / reordering of columns.
  void tableViewConversionTest();
  // Column type/description reporting through TableExaminerBase.
  void tableExaminerTest();
  // Shrinking and growing a Table with resize().
  void tableResizeTest();
  // Writing to table cells and rows.
  void mutabilityTest();
};
0049
// Column tags and table types shared by the tests in this file.
namespace ts {
  // Column tag written out by hand rather than with SOA_DECLARE_COLUMN;
  // presumably the macro expands to an equivalent struct — NOTE(review):
  // confirm against FWCore/SOA/interface/Column.h.
  struct Eta : public edm::soa::Column<float, Eta> {
    static constexpr const char* const kLabel = "eta";
  };

  // Column tags declared via the macro: (tag name, stored type, label).
  SOA_DECLARE_COLUMN(Phi, float, "phi");
  SOA_DECLARE_COLUMN(Energy, float, "energy");
  SOA_DECLARE_COLUMN(ID, int, "id");
  SOA_DECLARE_COLUMN(Label, std::string, "label");

  SOA_DECLARE_COLUMN(Px, double, "p_x");
  SOA_DECLARE_COLUMN(Py, double, "p_y");
  SOA_DECLARE_COLUMN(Pz, double, "p_z");

  // Momentum components plus energy.
  using ParticleTable = edm::soa::Table<Px, Py, Pz, Energy>;

  // Two float columns, Eta then Phi.
  using JetTable = edm::soa::Table<Eta, Phi>;

  // JetTable extended with a Label column.
  using MyJetTable = edm::soa::AddColumns_t<JetTable, std::tuple<Label>>;

  // MyJetTable with the Phi column removed.
  using MyOtherJetTable = edm::soa::RemoveColumn_t<MyJetTable, Phi>;
}
0074
0075
// Register the testTable suite with CppUnit's test registry; presumably the
// test driver included at the end of this file runs everything registered.
CPPUNIT_TEST_SUITE_REGISTRATION(testTable);
0077
0078 void testTable::rowviewCtrTest() {
0079 int id = 1;
0080 float eta = 3.14;
0081 float phi = 1.5;
0082 std::string label{"foo"};
0083
0084 std::array<void const*, 4> variables{{&id, &eta, &phi, &label}};
0085
0086 edm::soa::RowView<ts::ID, ts::Eta, ts::Phi, ts::Label> rv{variables};
0087
0088 CPPUNIT_ASSERT(rv.get<ts::ID>() == id);
0089 CPPUNIT_ASSERT(rv.get<ts::Eta>() == eta);
0090 CPPUNIT_ASSERT(rv.get<ts::Phi>() == phi);
0091 CPPUNIT_ASSERT(rv.get<ts::Label>() == label);
0092 }
0093
0094 void testTable::rawTableItrTest() {
0095 int ids[] = {1, 2, 3};
0096 float etas[] = {3.14, 2.5, 0.3};
0097 float phis[] = {1.5, -0.2, 2.9};
0098
0099 std::array<void*, 3> variables{{ids, etas, phis}};
0100
0101 edm::soa::TableItr<ts::ID, ts::Eta, ts::Phi> itr{variables};
0102
0103 for (unsigned int i = 0; i < std::size(ids); ++i, ++itr) {
0104 auto v = *itr;
0105 CPPUNIT_ASSERT(v.get<ts::ID>() == ids[i]);
0106 CPPUNIT_ASSERT(v.get<ts::Eta>() == etas[i]);
0107 CPPUNIT_ASSERT(v.get<ts::Phi>() == phis[i]);
0108 }
0109 }
0110
0111 namespace {
0112 std::vector<double> pTs(edm::soa::TableView<ts::Px, ts::Py> tv) {
0113 std::vector<double> results;
0114 results.reserve(tv.size());
0115
0116 for (auto const& r : tv) {
0117 auto px = r.get<ts::Px>();
0118 auto py = r.get<ts::Py>();
0119 results.push_back(std::sqrt(px * px + py * py));
0120 }
0121
0122 return results;
0123 }
0124
0125 template <typename C>
0126 void compareEta(edm::soa::TableView<ts::Eta> iEtas, C const& iContainer) {
0127 auto it = iContainer.begin();
0128 for (auto v : iEtas.column<ts::Eta>()) {
0129 CPPUNIT_ASSERT(v == *it);
0130 ++it;
0131 }
0132 }
0133
0134 struct JetType {
0135 double eta_;
0136 double phi_;
0137 };
0138
0139 double value_for_column(JetType const& iJ, ts::Eta*) { return iJ.eta_; }
0140
0141 double value_for_column(JetType const& iJ, ts::Phi*) { return iJ.phi_; }
0142
0143 bool tolerance(double a, double b) { return std::abs(a - b) < 10E-6; }
0144 }
0145
0146 void testTable::tableColumnTest() {
0147 using namespace ts;
0148 using namespace edm::soa;
0149 std::array<double, 3> eta = {{1., 2., 4.}};
0150 std::array<double, 3> phi = {{3.14, 0., 1.3}};
0151
0152 JetTable jets{eta, phi};
0153 {
0154 auto it = phi.begin();
0155 for (auto v : jets.column<Phi>()) {
0156 CPPUNIT_ASSERT(tolerance(*it, v));
0157 ++it;
0158 }
0159 }
0160
0161 {
0162 auto it = eta.begin();
0163 for (auto v : jets.column<Eta>()) {
0164 CPPUNIT_ASSERT(tolerance(*it, v));
0165 ++it;
0166 }
0167 }
0168 }
0169
0170 void testTable::tableViewConversionTest() {
0171 using namespace ts;
0172 using namespace edm::soa;
0173 std::array<double, 3> eta = {{1., 2., 4.}};
0174 std::array<double, 3> phi = {{3.14, 0., 1.3}};
0175
0176 JetTable jets{eta, phi};
0177
0178 compareEta(jets, eta);
0179
0180 {
0181 TableView<Phi, Eta> view{jets};
0182 auto itEta = eta.cbegin();
0183 auto itPhi = phi.cbegin();
0184 for (auto const& v : view) {
0185 CPPUNIT_ASSERT(tolerance(*itEta, v.get<Eta>()));
0186 CPPUNIT_ASSERT(tolerance(*itPhi, v.get<Phi>()));
0187 ++itEta;
0188 ++itPhi;
0189 }
0190 }
0191
0192 std::vector<double> px = {0.1, 0.9, 1.3};
0193 std::vector<double> py = {0.8, 1.7, 2.1};
0194 std::vector<double> pz = {0.4, 1.0, 0.7};
0195 std::vector<double> energy = {1.4, 3.7, 4.1};
0196
0197 ParticleTable particles{px, py, pz, energy};
0198
0199 {
0200 std::vector<double> ptCompare;
0201 ptCompare.reserve(px.size());
0202 for (unsigned int i = 0; i < px.size(); ++i) {
0203 ptCompare.push_back(sqrt(px[i] * px[i] + py[i] * py[i]));
0204 }
0205 auto it = ptCompare.begin();
0206 for (auto v : pTs(particles)) {
0207 CPPUNIT_ASSERT(tolerance(*it, v));
0208 ++it;
0209 }
0210 }
0211 }
0212
0213 void testTable::tableCtrTest() {
0214 using namespace ts;
0215 using namespace edm::soa;
0216 std::array<double, 3> eta = {{1., 2., 4.}};
0217 std::array<double, 3> phi = {{3.14, 0., 1.3}};
0218
0219 JetTable jets{eta, phi};
0220
0221 {
0222 auto itEta = eta.begin();
0223 auto itPhi = phi.begin();
0224 for (auto const& v : jets) {
0225 CPPUNIT_ASSERT(tolerance(*itEta, v.get<Eta>()));
0226 CPPUNIT_ASSERT(tolerance(*itPhi, v.get<Phi>()));
0227 ++itEta;
0228 ++itPhi;
0229 }
0230 }
0231 std::vector<double> px = {0.1, 0.9, 1.3};
0232 std::vector<double> py = {0.8, 1.7, 2.1};
0233 std::vector<double> pz = {0.4, 1.0, 0.7};
0234 std::vector<double> energy = {1.4, 3.7, 4.1};
0235
0236 ParticleTable particles{px, py, pz, energy};
0237
0238 {
0239 std::vector<JetType> j = {{1., 3.14}, {2., 0.}, {4., 1.3}};
0240 std::vector<std::string> labels = {{"jet0", "jet1", "jet2"}};
0241
0242 int index = 0;
0243 MyJetTable jt{j, column_fillers(Label::filler([&index](JetType const&) {
0244 std::ostringstream s;
0245 s << "jet" << index++;
0246 return s.str();
0247 }))};
0248 auto itJ = j.begin();
0249 auto itLabels = labels.begin();
0250 for (auto const& v : jt) {
0251 CPPUNIT_ASSERT(tolerance(itJ->eta_, v.get<Eta>()));
0252 CPPUNIT_ASSERT(tolerance(itJ->eta_, v.get<Eta>()));
0253 CPPUNIT_ASSERT(v.get<Label>() == *itLabels);
0254 ++itJ;
0255 ++itLabels;
0256 }
0257
0258 {
0259 auto itFillIndex = labels.begin();
0260 MyJetTable jt{j, column_fillers(Label::filler([&itFillIndex](JetType const&) { return *(itFillIndex++); }))};
0261 auto itLabels = labels.begin();
0262 for (auto const& v : jt) {
0263 CPPUNIT_ASSERT(v.get<Label>() == *itLabels);
0264 ++itLabels;
0265 }
0266 }
0267 }
0268 }
0269 void testTable::tableStandardOpsTest() {
0270 using namespace ts;
0271 using namespace edm::soa;
0272
0273 std::vector<double> px = {0.1, 0.9, 1.3};
0274 std::vector<double> py = {0.8, 1.7, 2.1};
0275 std::vector<double> pz = {0.4, 1.0, 0.7};
0276 std::vector<double> energy = {1.4, 3.7, 4.1};
0277
0278 ParticleTable particles{px, py, pz, energy};
0279
0280 {
0281 ParticleTable copyTable{particles};
0282
0283 auto compare = [](const ParticleTable& iLHS, const ParticleTable& iRHS) {
0284 CPPUNIT_ASSERT(iLHS.size() == iRHS.size());
0285 for (size_t i = 0; i < iRHS.size(); ++i) {
0286 CPPUNIT_ASSERT(iLHS.get<Px>(i) == iRHS.get<Px>(i));
0287 CPPUNIT_ASSERT(iLHS.get<Py>(i) == iRHS.get<Py>(i));
0288 CPPUNIT_ASSERT(iLHS.get<Pz>(i) == iRHS.get<Pz>(i));
0289 CPPUNIT_ASSERT(iLHS.get<Energy>(i) == iRHS.get<Energy>(i));
0290 }
0291 };
0292 compare(copyTable, particles);
0293
0294 ParticleTable moveTable(std::move(copyTable));
0295 compare(moveTable, particles);
0296
0297 ParticleTable opEqTable;
0298 opEqTable = particles;
0299 compare(opEqTable, particles);
0300
0301 ParticleTable opEqMvTable;
0302 opEqMvTable = std::move(moveTable);
0303 compare(opEqMvTable, particles);
0304 }
0305 }
0306
0307 namespace {
0308 void checkColumnTypes(edm::soa::TableExaminerBase& reader) {
0309 auto columns = reader.columnTypes();
0310 std::array<std::type_index, 2> const types{{typeid(ts::Eta), typeid(ts::Phi)}};
0311
0312 auto itT = types.begin();
0313 for (auto c : columns) {
0314 CPPUNIT_ASSERT(c == *itT);
0315 ++itT;
0316 }
0317 };
0318
0319 void checkColumnDescriptions(edm::soa::TableExaminerBase& reader) {
0320 auto columns = reader.columnDescriptions();
0321
0322 std::array<std::string, 2> const desc{{"eta", "phi"}};
0323 std::array<std::type_index, 2> const types{{typeid(float), typeid(float)}};
0324
0325 auto itD = desc.begin();
0326 auto itT = types.begin();
0327
0328 for (auto c : columns) {
0329 CPPUNIT_ASSERT(c.first == *itD);
0330 CPPUNIT_ASSERT(c.second == *itT);
0331 ++itD;
0332 ++itT;
0333 }
0334 };
0335
0336 }
0337
0338 void testTable::tableExaminerTest() {
0339 using namespace edm::soa;
0340 using namespace ts;
0341
0342 std::array<double, 3> eta = {{1., 2., 4.}};
0343 std::array<double, 3> phi = {{3.14, 0., 1.3}};
0344 int size = eta.size();
0345 CPPUNIT_ASSERT(size == 3);
0346
0347 JetTable jets{eta, phi};
0348
0349 TableExaminer<JetTable> r(&jets);
0350 checkColumnTypes(r);
0351 checkColumnDescriptions(r);
0352 }
0353
0354 void testTable::tableResizeTest() {
0355 using namespace edm::soa;
0356 using namespace ts;
0357
0358 std::vector<double> px = {0.1, 0.9, 1.3};
0359 std::vector<double> py = {0.8, 1.7, 2.1};
0360 std::vector<double> pz = {0.4, 1.0, 0.7};
0361 std::vector<double> energy = {1.4, 3.7, 4.1};
0362
0363 ParticleTable particlesStandard{px, py, pz, energy};
0364
0365 ParticleTable particles{px, py, pz, energy};
0366
0367 particles.resize(2);
0368
0369 auto compare = [](const ParticleTable& iLHS, const ParticleTable& iRHS, size_t n) {
0370 for (size_t i = 0; i < n; ++i) {
0371 CPPUNIT_ASSERT(iLHS.get<Px>(i) == iRHS.get<Px>(i));
0372 CPPUNIT_ASSERT(iLHS.get<Py>(i) == iRHS.get<Py>(i));
0373 CPPUNIT_ASSERT(iLHS.get<Pz>(i) == iRHS.get<Pz>(i));
0374 CPPUNIT_ASSERT(iLHS.get<Energy>(i) == iRHS.get<Energy>(i));
0375 }
0376 };
0377
0378 CPPUNIT_ASSERT(particles.size() == 2);
0379 compare(particlesStandard, particles, 2);
0380
0381 particles.resize(4);
0382 CPPUNIT_ASSERT(particles.size() == 4);
0383 compare(particles, particlesStandard, 2);
0384
0385 for (size_t i = 2; i < 4; ++i) {
0386 CPPUNIT_ASSERT(particles.get<Px>(i) == 0.);
0387 CPPUNIT_ASSERT(particles.get<Py>(i) == 0.);
0388 CPPUNIT_ASSERT(particles.get<Pz>(i) == 0.);
0389 CPPUNIT_ASSERT(particles.get<Energy>(i) == 0.);
0390 }
0391 }
0392
0393 void testTable::mutabilityTest() {
0394 using namespace edm::soa;
0395 using namespace ts;
0396
0397 std::array<double, 3> eta = {{1., 2., 4.}};
0398 std::array<double, 3> phi = {{3.14, 0., 1.3}};
0399 JetTable jets{eta, phi};
0400
0401 jets.get<Eta>(0) = 0.;
0402 CPPUNIT_ASSERT(jets.get<Eta>(0) == 0.);
0403 jets.get<Phi>(1) = 0.03;
0404 CPPUNIT_ASSERT(tolerance(jets.get<Phi>(1), 0.03));
0405
0406 auto row = jets.row(2);
0407 CPPUNIT_ASSERT(row.get<Eta>() == 4.);
0408 CPPUNIT_ASSERT(tolerance(row.get<Phi>(), 1.3));
0409
0410 row.copyValuesFrom(JetType{5., 6.});
0411 CPPUNIT_ASSERT(row.get<Eta>() == 5.);
0412 CPPUNIT_ASSERT(row.get<Phi>() == 6.);
0413
0414 row.copyValuesFrom(JetType{7., 8.}, column_fillers(Phi::filler([](JetType const&) { return 9.; })));
0415 CPPUNIT_ASSERT(row.get<Eta>() == 7.);
0416 CPPUNIT_ASSERT(row.get<Phi>() == 9.);
0417
0418 row.set<Phi>(10.).set<Eta>(11.);
0419 CPPUNIT_ASSERT(row.get<Eta>() == 11.);
0420 CPPUNIT_ASSERT(row.get<Phi>() == 10.);
0421 }
0422
0423 #include <Utilities/Testing/interface/CppUnit_testdriver.icpp>