diff --git a/include/xtensor/xstrided_view.hpp b/include/xtensor/xstrided_view.hpp index a65f9467a..63b3f34bc 100644 --- a/include/xtensor/xstrided_view.hpp +++ b/include/xtensor/xstrided_view.hpp @@ -153,10 +153,36 @@ namespace xt using inner_storage_type = typename base_type::inner_storage_type; using storage_type = typename base_type::storage_type; - using linear_iterator = typename storage_type::iterator; - using const_linear_iterator = typename storage_type::const_iterator; - using reverse_linear_iterator = std::reverse_iterator; - using const_reverse_linear_iterator = std::reverse_iterator; + + template > + struct get_linear_iterator : std::false_type + { + using iterator = typename C::iterator; + }; + + template + struct get_linear_iterator().linear_begin())>> : std::true_type + { + using iterator = typename C::linear_iterator; + }; + + template > + struct get_const_linear_iterator : std::false_type + { + using iterator = typename C::const_iterator; + }; + + template + struct get_const_linear_iterator().linear_cbegin())>> : std::true_type + { + using iterator = typename C::const_linear_iterator; + }; + + using linear_iterator = typename get_linear_iterator::iterator; + using const_linear_iterator = typename get_const_linear_iterator::iterator; + using reverse_linear_iterator = std::reverse_iterator::iterator>; + using const_reverse_linear_iterator = std::reverse_iterator< + typename get_const_linear_iterator::iterator>; using iterable_base = select_iterable_base_t; using inner_shape_type = typename base_type::inner_shape_type; @@ -222,6 +248,7 @@ namespace xt const_linear_iterator linear_begin() const; const_linear_iterator linear_end() const; const_linear_iterator linear_cbegin() const; + const_linear_iterator linear_cend() const; reverse_linear_iterator linear_rbegin(); @@ -511,7 +538,16 @@ namespace xt template inline auto xstrided_view::linear_cbegin() const -> const_linear_iterator { - return this->storage().cbegin() + static_cast(data_offset()); + return xtl::mpl::static_if::value>( + [&](auto self) + { + return self(this->storage()).linear_cbegin() + static_cast(data_offset()); + }, + [&](auto self) + { + return self(this->storage()).cbegin() + static_cast(data_offset()); + } + ); } template diff --git a/include/xtensor/xstrided_view_base.hpp b/include/xtensor/xstrided_view_base.hpp index 9f3f8ef4c..020a8add3 100644 --- a/include/xtensor/xstrided_view_base.hpp +++ b/include/xtensor/xstrided_view_base.hpp @@ -48,6 +48,7 @@ namespace xt using reverse_iterator = decltype(std::declval>().template rbegin()); using const_reverse_iterator = decltype(std::declval>().template crbegin()); + explicit flat_expression_adaptor(CT* e); template @@ -75,6 +76,52 @@ namespace xt size_type m_size; }; + template + class linear_flat_expression_adaptor : public flat_expression_adaptor + { + public: + + using xexpression_type = std::decay_t; + using shape_type = typename xexpression_type::shape_type; + using inner_strides_type = get_strides_t; + using index_type = inner_strides_type; + using size_type = typename xexpression_type::size_type; + using value_type = typename xexpression_type::value_type; + using const_reference = typename xexpression_type::const_reference; + using reference = std::conditional_t< + std::is_const>::value, + typename xexpression_type::const_reference, + typename xexpression_type::reference>; + + + using linear_iterator = decltype(std::declval>().linear_begin()); + using const_linear_iterator = decltype(std::declval>().linear_cbegin()); + using reverse_linear_iterator = decltype(std::declval>().linear_rbegin() + ); + using const_reverse_linear_iterator = decltype(std::declval>().linear_crbegin()); + + + explicit linear_flat_expression_adaptor(CT* e); + + template + linear_flat_expression_adaptor(CT* e, FST&& strides); + + linear_iterator linear_begin(); + linear_iterator linear_end(); + const_linear_iterator linear_begin() const; + const_linear_iterator linear_end() const; + const_linear_iterator linear_cbegin() const; + const_linear_iterator linear_cend() const; + + private: + + static index_type& get_index(); + + mutable CT* m_e; + inner_strides_type m_strides; + size_type m_size; + }; + template struct is_flat_expression_adaptor : std::false_type { @@ -85,9 +132,21 @@ namespace xt { }; + template + struct is_linear_flat_expression_adaptor : std::false_type + { + }; + + template + struct is_linear_flat_expression_adaptor> : std::true_type + { + }; + template - struct provides_data_interface - : xtl::conjunction>, xtl::negation>> + struct provides_data_interface : xtl::conjunction< + has_data_interface>, + xtl::negation>, + xtl::negation>> { }; } @@ -246,7 +305,11 @@ namespace xt template struct flat_adaptor_getter { - using type = flat_expression_adaptor, L>; + using type = std::conditional_t< + detail::has_linear_iterator>::value + && (std::remove_reference_t::static_layout == L), + linear_flat_expression_adaptor, L>, + flat_expression_adaptor, L>>; using reference = std::add_lvalue_reference_t; template @@ -318,9 +381,7 @@ namespace xt layout_type layout ) noexcept : m_e(std::forward(e)) - , - // m_storage(detail::get_flat_storage(m_e)), - m_storage(storage_getter::get_flat_storage(m_e)) + , m_storage(storage_getter::get_flat_storage(m_e)) , m_shape(std::forward(shape)) , m_strides(std::move(strides)) , m_offset(offset) @@ -345,6 +406,14 @@ namespace xt new_storage.update_pointer(std::addressof(expr)); return new_storage; } + + template + auto copy_move_storage(T& expr, const detail::linear_flat_expression_adaptor& storage) + { + detail::linear_flat_expression_adaptor new_storage = storage; // copy storage + new_storage.update_pointer(std::addressof(expr)); + return new_storage; + } } template @@ -783,6 +852,58 @@ namespace xt thread_local static index_type index; return index; } + + template + inline linear_flat_expression_adaptor::linear_flat_expression_adaptor(CT* e) + : flat_expression_adaptor(e) + , m_e(e) + { + } + + template + template + inline linear_flat_expression_adaptor::linear_flat_expression_adaptor(CT* e, FST&& strides) + : flat_expression_adaptor(e, strides) + , m_e(e) + , m_strides(xtl::forward_sequence(strides)) + { + } + + template + inline auto linear_flat_expression_adaptor::linear_begin() -> linear_iterator + { + return m_e->linear_begin(); + } + + template + inline auto linear_flat_expression_adaptor::linear_end() -> linear_iterator + { + return m_e->linear_end(); + } + + template + inline auto linear_flat_expression_adaptor::linear_begin() const -> const_linear_iterator + { + return m_e->linear_cbegin(); + } + + template + inline auto linear_flat_expression_adaptor::linear_end() const -> const_linear_iterator + { + return m_e->linear_cend(); + } + + template + inline auto linear_flat_expression_adaptor::linear_cbegin() const -> const_linear_iterator + { + return m_e->linear_cbegin(); + } + + template + inline auto linear_flat_expression_adaptor::linear_cend() const -> const_linear_iterator + { + return m_e->linear_cend(); + } } /**********************************