// Copyright 2005-2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the 'License'); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an 'AS IS' BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // See www.openfst.org for extensive documentation on this weighted // finite-state transducer library. #ifndef FST_EXTENSIONS_LINEAR_TRIE_H_ #define FST_EXTENSIONS_LINEAR_TRIE_H_ #include #include #include #include #include #include #include #include #include #include namespace fst { inline constexpr int kNoTrieNodeId = -1; template class FlatTrieTopology; // Forward declarations of all available trie topologies. template class NestedTrieTopology; // A pair of parent node id and label, part of a trie edge template struct ParentLabel { int parent; L label; ParentLabel() = default; ParentLabel(int p, L l) : parent(p), label(l) {} bool operator==(const ParentLabel &that) const { return parent == that.parent && label == that.label; } std::istream &Read(std::istream &strm) { ReadType(strm, &parent); ReadType(strm, &label); return strm; } std::ostream &Write(std::ostream &strm) const { WriteType(strm, parent); WriteType(strm, label); return strm; } }; template struct ParentLabelHash { size_t operator()(const ParentLabel &pl) const { return static_cast(pl.parent * 7853 + H()(pl.label)); } }; // The trie topology in a nested tree of hash maps; allows efficient // iteration over children of a specific node. template class NestedTrieTopology { public: typedef L Label; typedef H Hash; typedef std::unordered_map NextMap; class const_iterator { public: typedef std::forward_iterator_tag iterator_category; typedef std::pair, int> value_type; typedef std::ptrdiff_t difference_type; typedef const value_type *pointer; typedef const value_type &reference; friend class NestedTrieTopology; const_iterator() : ptr_(nullptr), cur_node_(kNoTrieNodeId), cur_edge_() {} reference operator*() { UpdateStub(); return stub_; } pointer operator->() { UpdateStub(); return &stub_; } const_iterator &operator++(); const_iterator &operator++(int); bool operator==(const const_iterator &that) const { return ptr_ == that.ptr_ && cur_node_ == that.cur_node_ && cur_edge_ == that.cur_edge_; } bool operator!=(const const_iterator &that) const { return !(*this == that); } private: const_iterator(const NestedTrieTopology *ptr, int cur_node) : ptr_(ptr), cur_node_(cur_node) { SetProperCurEdge(); } void SetProperCurEdge() { if (cur_node_ < ptr_->NumNodes()) cur_edge_ = ptr_->nodes_[cur_node_]->begin(); else cur_edge_ = ptr_->nodes_[0]->begin(); } void UpdateStub() { stub_.first = ParentLabel(cur_node_, cur_edge_->first); stub_.second = cur_edge_->second; } const NestedTrieTopology *ptr_; int cur_node_; typename NextMap::const_iterator cur_edge_; value_type stub_; }; NestedTrieTopology(); NestedTrieTopology(const NestedTrieTopology &that); ~NestedTrieTopology() = default; void swap(NestedTrieTopology &that); NestedTrieTopology &operator=(const NestedTrieTopology &that); bool operator==(const NestedTrieTopology &that) const; bool operator!=(const NestedTrieTopology &that) const; int Root() const { return 0; } size_t NumNodes() const { return nodes_.size(); } int Insert(int parent, const L &label); int Find(int parent, const L &label) const; const NextMap &ChildrenOf(int parent) const { return *nodes_[parent]; } std::istream &Read(std::istream &strm); std::ostream &Write(std::ostream &strm) const; const_iterator begin() const { return const_iterator(this, 0); } const_iterator end() const { return const_iterator(this, NumNodes()); } private: // Use pointers to avoid copying the maps when the vector grows std::vector> nodes_; }; template NestedTrieTopology::NestedTrieTopology() { nodes_.push_back(std::make_unique()); } template NestedTrieTopology::NestedTrieTopology(const NestedTrieTopology &that) { nodes_.reserve(that.nodes_.size()); for (const auto &node : that.nodes_) { nodes_.push_back(std::make_unique(*node)); } } // TODO(wuke): std::swap compatibility template inline void NestedTrieTopology::swap(NestedTrieTopology &that) { nodes_.swap(that.nodes_); } template inline NestedTrieTopology &NestedTrieTopology::operator=( const NestedTrieTopology &that) { NestedTrieTopology copy(that); swap(copy); return *this; } template inline bool NestedTrieTopology::operator==( const NestedTrieTopology &that) const { if (NumNodes() != that.NumNodes()) return false; for (int i = 0; i < NumNodes(); ++i) if (ChildrenOf(i) != that.ChildrenOf(i)) return false; return true; } template inline bool NestedTrieTopology::operator!=( const NestedTrieTopology &that) const { return !(*this == that); } template inline int NestedTrieTopology::Insert(int parent, const L &label) { int ret = Find(parent, label); if (ret == kNoTrieNodeId) { ret = NumNodes(); (*nodes_[parent])[label] = ret; nodes_.push_back(std::make_unique()); } return ret; } template inline int NestedTrieTopology::Find(int parent, const L &label) const { typename NextMap::const_iterator it = nodes_[parent]->find(label); return it == nodes_[parent]->end() ? kNoTrieNodeId : it->second; } template inline std::istream &NestedTrieTopology::Read(std::istream &strm) { NestedTrieTopology new_trie; size_t num_nodes; if (!ReadType(strm, &num_nodes)) return strm; for (size_t i = 1; i < num_nodes; ++i) { new_trie.nodes_.push_back(std::make_unique()); } for (size_t i = 0; i < num_nodes; ++i) { ReadType(strm, new_trie.nodes_[i].get()); } if (strm) swap(new_trie); return strm; } template inline std::ostream &NestedTrieTopology::Write(std::ostream &strm) const { WriteType(strm, NumNodes()); for (size_t i = 0; i < NumNodes(); ++i) WriteType(strm, *nodes_[i]); return strm; } template inline typename NestedTrieTopology::const_iterator & NestedTrieTopology::const_iterator::operator++() { ++cur_edge_; if (cur_edge_ == ptr_->nodes_[cur_node_]->end()) { ++cur_node_; while (cur_node_ < ptr_->NumNodes() && ptr_->nodes_[cur_node_]->empty()) ++cur_node_; SetProperCurEdge(); } return *this; } template inline typename NestedTrieTopology::const_iterator & NestedTrieTopology::const_iterator::operator++(int) { // NOLINT const_iterator save(*this); ++(*this); return save; } // The trie topology in a single hash map; only allows iteration over // all the edges in arbitrary order. template class FlatTrieTopology { private: typedef std::unordered_map, int, ParentLabelHash> NextMap; public: // Iterator over edges as std::pair, int> typedef typename NextMap::const_iterator const_iterator; typedef L Label; typedef H Hash; FlatTrieTopology() = default; FlatTrieTopology(const FlatTrieTopology &that) : next_(that.next_) {} template explicit FlatTrieTopology(const T &that); // TODO(wuke): std::swap compatibility void swap(FlatTrieTopology &that) { next_.swap(that.next_); } bool operator==(const FlatTrieTopology &that) const { return next_ == that.next_; } bool operator!=(const FlatTrieTopology &that) const { return !(*this == that); } int Root() const { return 0; } size_t NumNodes() const { return next_.size() + 1; } int Insert(int parent, const L &label); int Find(int parent, const L &label) const; std::istream &Read(std::istream &strm) { return ReadType(strm, &next_); } std::ostream &Write(std::ostream &strm) const { return WriteType(strm, next_); } const_iterator begin() const { return next_.begin(); } const_iterator end() const { return next_.end(); } private: NextMap next_; }; template template FlatTrieTopology::FlatTrieTopology(const T &that) : next_(that.begin(), that.end()) {} template inline int FlatTrieTopology::Insert(int parent, const L &label) { int ret = Find(parent, label); if (ret == kNoTrieNodeId) { ret = NumNodes(); next_[ParentLabel(parent, label)] = ret; } return ret; } template inline int FlatTrieTopology::Find(int parent, const L &label) const { typename NextMap::const_iterator it = next_.find(ParentLabel(parent, label)); return it == next_.end() ? kNoTrieNodeId : it->second; } // A collection of implementations of the trie data structure. The key // is a sequence of type `L` which must be hashable. The value is of // `V` which must be default constructible and copyable. In addition, // a value object is stored for each node in the trie therefore // copying `V` should be cheap. // // One can access the store values with an integer node id, using the // [] operator. A valid node id can be obtained by the following ways: // // 1. Using the `Root()` method to get the node id of the root. // // 2. Iterating through 0 to `NumNodes() - 1`. The node ids are dense // so every integer in this range is a valid node id. // // 3. Using the node id returned from a successful `Insert()` or // `Find()` call. // // 4. Iterating over the trie edges with an `EdgeIterator` and using // the node ids returned from its `Parent()` and `Child()` methods. // // Below is an example of inserting keys into the trie: // // const string words[] = {"hello", "health", "jello"}; // Trie dict; // for (auto word : words) { // int cur = dict.Root(); // for (char c : word) { // cur = dict.Insert(cur, c); // } // dict[cur] = true; // } // // And the following is an example of looking up the longest prefix of // a string using the trie constructed above: // // string query = "healed"; // size_t prefix_length = 0; // int cur = dict.Find(dict.Root(), query[prefix_length]); // while (prefix_length < query.size() && // cur != Trie::kNoNodeId) { // ++prefix_length; // cur = dict.Find(cur, query[prefix_length]); // } template class MutableTrie { public: template friend class MutableTrie; typedef L Label; typedef V Value; typedef T Topology; // Constructs a trie with only the root node. MutableTrie() = default; // Conversion from another trie of a possiblly different // topology. The underlying topology must supported conversion. template explicit MutableTrie(const MutableTrie &that) : topology_(that.topology_), values_(that.values_) {} // TODO(wuke): std::swap compatibility void swap(MutableTrie &that) { topology_.swap(that.topology_); values_.swap(that.values_); } int Root() const { return topology_.Root(); } size_t NumNodes() const { return topology_.NumNodes(); } // Inserts an edge with given `label` at node `parent`. Returns the // child node id. If the node already exists, returns the node id // right away. int Insert(int parent, const L &label) { int ret = topology_.Insert(parent, label); values_.resize(NumNodes()); return ret; } // Finds the node id of the node from `parent` via `label`. Returns // `kNoTrieNodeId` when such a node does not exist. int Find(int parent, const L &label) const { return topology_.Find(parent, label); } const T &TrieTopology() const { return topology_; } // Accesses the value stored for the given node. V &operator[](int node_id) { return values_[node_id]; } const V &operator[](int node_id) const { return values_[node_id]; } // Comparison by content bool operator==(const MutableTrie &that) const { return topology_ == that.topology_ && values_ == that.values_; } bool operator!=(const MutableTrie &that) const { return !(*this == that); } std::istream &Read(std::istream &strm) { ReadType(strm, &topology_); ReadType(strm, &values_); return strm; } std::ostream &Write(std::ostream &strm) const { WriteType(strm, topology_); WriteType(strm, values_); return strm; } private: T topology_; std::vector values_; }; } // namespace fst #endif // FST_EXTENSIONS_LINEAR_TRIE_H_