diff --git a/source/simple/geom/vector.hpp b/source/simple/geom/vector.hpp index 210c13ecd6e9ef1fc03e2cf18d5a3302bc167003..314cc7c5a3413e99ac76025d058ab97976fcad73 100644 --- a/source/simple/geom/vector.hpp +++ b/source/simple/geom/vector.hpp @@ -401,7 +401,7 @@ namespace simple::geom } template <typename C = typename meta::coordinate_type, - std::enable_if_t<std::is_same_v<C,detail::disjunctive_bool>>* = nullptr> + std::enable_if_t<std::is_same_v<C,detail::disjunctive_bool>>* = nullptr> [[nodiscard]] constexpr explicit operator bool() const noexcept { @@ -669,10 +669,22 @@ SIMPLE_GEOM_VECTOR_DEFINE_COMPARISON_OPERATOR(<=, bool_vector) template <typename T> constexpr static bool can_apply = can_apply_s<T>::value; + // TODO: common declval code between this and can_apply + template <typename T, typename = std::nullptr_t> + struct product_result_s { using type = Coordinate; }; + template <typename T> + struct product_result_s<T, decltype(void(std::declval<vector>()(std::declval<T>())), nullptr)> { using type = decltype(std::declval<vector>()(std::declval<T>())); }; + template <typename T> + using product_result = typename product_result_s<T>::type; + // matrix multiplication and matrix-vector multiplication/dot product fusion mutant operator template<typename AnotherComponent, size_t AnotherDimesnions, typename AnotherOrder, std::enable_if_t<std::is_same_v<Order,AnotherOrder> || can_apply<AnotherComponent>>* = nullptr, - typename Return = std::conditional_t<can_apply<AnotherComponent>, vector<Coordinate, AnotherDimesnions, AnotherOrder>, Coordinate>> + typename Return = std::conditional_t<can_apply<AnotherComponent>, + vector<product_result<AnotherComponent>, AnotherDimesnions, AnotherOrder>, + Coordinate + > + > [[nodiscard]] constexpr Return operator()(const vector<AnotherComponent, AnotherDimesnions, AnotherOrder> & another) const { @@ -906,14 +918,14 @@ SIMPLE_GEOM_VECTOR_DEFINE_COMPARISON_OPERATOR(<=, bool_vector) std::ostream & operator<<(std::ostream & out, const vector<vector<Coordinate, N, O1>, M, O2> & vector) { for(size_t i = 0; i < N; ++i) - out << "---"; + out << "^"; out << "\n"; for(size_t i = 0; i < M; ++i) out << vector[i] << "\n"; for(size_t i = 0; i < N; ++i) - out << "---"; + out << "v"; out << "\n"; return out; diff --git a/unit_tests/sym.cpp b/unit_tests/sym.cpp index 477c76448aa78623f1114f244ea7f5b142f41bdf..79790302d7bd3d357dc45e09c5f20879654f905e 100644 --- a/unit_tests/sym.cpp +++ b/unit_tests/sym.cpp @@ -1,3 +1,4 @@ +#include <iostream> #include <string> #include <cassert> #include <sstream> @@ -91,21 +92,102 @@ const auto& make_rotor = dot_wedge<T, D, O>; void Explode() { - auto a = vector( sym_vector("e1", "e2", "e3") ); // 3x1 - auto b = sym_vector("e1", "e2", "e3").mutant_clone(wrap_in_vector); // 1x3 + auto a = vector( sym_vector("x1", "x2", "x3") ); // 3x1 + auto b = sym_vector("y1", "y2", "y3").mutant_clone(wrap_in_vector); // 1x3 + + ss << a(b); // 3x1 * 1x3 = 3x3 + + assert( ss.str() == +R"(^^^ +(x1*y1, x2*y1, x3*y1) +(x1*y2, x2*y2, x3*y2) +(x1*y3, x2*y3, x3*y3) +vvv +)" + ); ss.str(""); + + ss << b(a); // 1x3 * 3x1 = 1x1 + assert( ss.str() == +R"(^ +(y1*x1+y2*x2+y3*x3) +v +)" + ); ss.str(""); + + auto ab = vector(a(b)); // 3x3x1 - ss << ab(b) << '\n'; // 3x3x1 * 1x3 = 3x3x3 - // ss << (ab(b))(ab) << '\n'; // can't 3x3x3 * 3x3x1 = 3x3x3x1 - // ss << b(ab) << '\n'; // 1x3 * 3x3x1 = 1x3x1 - ss << a(b) << '\n'; - ss << b(a) << '\n'; + + ss << ab(b); // 3x3x1 * 1x3 = 3x3x3 + assert( ss.str() == +R"(^^^ +^^^ +(x1*y1*y1, x2*y1*y1, x3*y1*y1) +(x1*y2*y1, x2*y2*y1, x3*y2*y1) +(x1*y3*y1, x2*y3*y1, x3*y3*y1) +vvv + +^^^ +(x1*y1*y2, x2*y1*y2, x3*y1*y2) +(x1*y2*y2, x2*y2*y2, x3*y2*y2) +(x1*y3*y2, x2*y3*y2, x3*y3*y2) +vvv + +^^^ +(x1*y1*y3, x2*y1*y3, x3*y1*y3) +(x1*y2*y3, x2*y2*y3, x3*y2*y3) +(x1*y3*y3, x2*y3*y3, x3*y3*y3) +vvv + +vvv +)" + ); ss.str(""); + + ss << b(ab); // 1x3 * 3x3x1 = 1x3x1 + assert( ss.str() == +R"(^^^ +^ +(y1*x1*y1+y2*x2*y1+y3*x3*y1) +(y1*x1*y2+y2*x2*y2+y3*x3*y2) +(y1*x1*y3+y2*x2*y3+y3*x3*y3) +v + +vvv +)" + ); ss.str(""); + + ss << (ab(b))(ab); // brace yourselves: 3x3x3 * 3x3x1 = 3x3x3x1 + assert( ss.str() == +R"(^^^ +^^^ +^^^ +(x1*y1*y1*x1*y1+x1*y1*y2*x2*y1+x1*y1*y3*x3*y1, x2*y1*y1*x1*y1+x2*y1*y2*x2*y1+x2*y1*y3*x3*y1, x3*y1*y1*x1*y1+x3*y1*y2*x2*y1+x3*y1*y3*x3*y1) +(x1*y2*y1*x1*y1+x1*y2*y2*x2*y1+x1*y2*y3*x3*y1, x2*y2*y1*x1*y1+x2*y2*y2*x2*y1+x2*y2*y3*x3*y1, x3*y2*y1*x1*y1+x3*y2*y2*x2*y1+x3*y2*y3*x3*y1) +(x1*y3*y1*x1*y1+x1*y3*y2*x2*y1+x1*y3*y3*x3*y1, x2*y3*y1*x1*y1+x2*y3*y2*x2*y1+x2*y3*y3*x3*y1, x3*y3*y1*x1*y1+x3*y3*y2*x2*y1+x3*y3*y3*x3*y1) +vvv + +^^^ +(x1*y1*y1*x1*y2+x1*y1*y2*x2*y2+x1*y1*y3*x3*y2, x2*y1*y1*x1*y2+x2*y1*y2*x2*y2+x2*y1*y3*x3*y2, x3*y1*y1*x1*y2+x3*y1*y2*x2*y2+x3*y1*y3*x3*y2) +(x1*y2*y1*x1*y2+x1*y2*y2*x2*y2+x1*y2*y3*x3*y2, x2*y2*y1*x1*y2+x2*y2*y2*x2*y2+x2*y2*y3*x3*y2, x3*y2*y1*x1*y2+x3*y2*y2*x2*y2+x3*y2*y3*x3*y2) +(x1*y3*y1*x1*y2+x1*y3*y2*x2*y2+x1*y3*y3*x3*y2, x2*y3*y1*x1*y2+x2*y3*y2*x2*y2+x2*y3*y3*x3*y2, x3*y3*y1*x1*y2+x3*y3*y2*x2*y2+x3*y3*y3*x3*y2) +vvv + +^^^ +(x1*y1*y1*x1*y3+x1*y1*y2*x2*y3+x1*y1*y3*x3*y3, x2*y1*y1*x1*y3+x2*y1*y2*x2*y3+x2*y1*y3*x3*y3, x3*y1*y1*x1*y3+x3*y1*y2*x2*y3+x3*y1*y3*x3*y3) +(x1*y2*y1*x1*y3+x1*y2*y2*x2*y3+x1*y2*y3*x3*y3, x2*y2*y1*x1*y3+x2*y2*y2*x2*y3+x2*y2*y3*x3*y3, x3*y2*y1*x1*y3+x3*y2*y2*x2*y3+x3*y2*y3*x3*y3) +(x1*y3*y1*x1*y3+x1*y3*y2*x2*y3+x1*y3*y3*x3*y3, x2*y3*y1*x1*y3+x2*y3*y2*x2*y3+x2*y3*y3*x3*y3, x3*y3*y1*x1*y3+x3*y3*y2*x2*y3+x3*y3*y3*x3*y3) +vvv + +vvv + +vvv +)" + ); ss.str(""); + + auto a2 = vector( sym_vector("a", "b") ); // 2x1 auto b2 = sym_vector("c", "d").mutant_clone(wrap_in_vector); // 1x2 - // ss << a2(b2) << '\n'; - // ss << vector(a2(b2))(b2) << '\n'; ss << a2(b2) << '\n'; - // ss << b2(a2) << '\n'; ss << dot_wedge( sym_vector("a", "b"), @@ -123,47 +205,60 @@ void Explode() // = aca - acib - bda + bdib + adia + bcia - adiib - bciib = // = aca - bda + bdib + adia + adb + bcb -std::string expected = R"(--------- ---------- -(e1*e1*e1, e2*e1*e1, e3*e1*e1) -(e1*e2*e1, e2*e2*e1, e3*e2*e1) -(e1*e3*e1, e2*e3*e1, e3*e3*e1) ---------- - ---------- -(e1*e1*e2, e2*e1*e2, e3*e1*e2) -(e1*e2*e2, e2*e2*e2, e3*e2*e2) -(e1*e3*e2, e2*e3*e2, e3*e3*e2) ---------- - ---------- -(e1*e1*e3, e2*e1*e3, e3*e1*e3) -(e1*e2*e3, e2*e2*e3, e3*e2*e3) -(e1*e3*e3, e2*e3*e3, e3*e3*e3) ---------- - ---------- - ---------- -(e1*e1, e2*e1, e3*e1) -(e1*e2, e2*e2, e3*e2) -(e1*e3, e2*e3, e3*e3) ---------- - ---- -(e1*e1+e2*e2+e3*e3) ---- - ------- + assert(ss.str() == +R"(^^ (a*c, b*c) (a*d, b*d) ------- +vv (b*c-a*d, a*c+b*d) (u2*v1-u1*v2, u3*v1-u1*v3, u3*v2-u2*v3, u1*v1+u2*v2+u3*v3) -)"; - - assert(ss.str() == expected); +)" + ); ss.str(""); + + + auto aa = vector( + sym_vector("x1", "x2"), + sym_vector("x3", "x4") + ); + auto bb = vector( + sym_vector("y1", "y2"), + sym_vector("y3", "y4") + ); + auto A = vector( aa ); // 2x2x1 + auto B = bb.mutant_clone([](auto x){ return x.mutant_clone(wrap_in_vector);}); // 1x2x2 + ss << A(B); // 2x2x1 * 1x2x2 = 2x2x2x2 + assert(ss.str() == +R"(^^ +^^ +^^ +(x1*y1, x2*y1) +(x3*y1, x4*y1) +vv + +^^ +(x1*y2, x2*y2) +(x3*y2, x4*y2) +vv + +vv + +^^ +^^ +(x1*y3, x2*y3) +(x3*y3, x4*y3) +vv + +^^ +(x1*y4, x2*y4) +(x3*y4, x4*y4) +vv + +vv + +vv +)" + ); ss.str(""); }