aboutsummaryrefslogtreecommitdiff
path: root/trie.cc
diff options
context:
space:
mode:
Diffstat (limited to 'trie.cc')
-rw-r--r--trie.cc38
1 files changed, 15 insertions, 23 deletions
diff --git a/trie.cc b/trie.cc
index 15a74c9..3338a8c 100644
--- a/trie.cc
+++ b/trie.cc
@@ -1,7 +1,6 @@
#include <cstdint>
#include <forward_list>
#include <map>
-#include <memory>
#include <cassert>
#include <iostream>
@@ -11,8 +10,6 @@ template <
>
class Trie {
public:
- typedef std::unique_ptr<Trie> ptr;
-
Trie():
children_() { }
@@ -23,22 +20,17 @@ class Trie {
inline void add(std::forward_list<Key>& path,
typename std::forward_list<Key>::const_iterator curr) {
if ( curr != path.end() ) {
- Trie::ptr& trie = this->children_[*curr];
+ Trie& trie = this->children_[*curr];
- if ( trie ) {
- trie->add(path, ++curr);
- } else {
- trie.reset(new Trie<Key>());
- trie->add(path, ++curr);
- }
+ trie.add(path, ++curr);
}
}
- inline Trie* resolve(std::forward_list<Key> path) const {
+ inline const Trie* resolve(std::forward_list<Key> path) const {
return this->resolve(path, path.begin());
}
- inline Trie* resolve(
+ inline const Trie* resolve(
std::forward_list<Key>& path,
typename std::forward_list<Key>::const_iterator curr
) const {
@@ -48,29 +40,29 @@ class Trie {
auto next = ++curr;
if ( next == path.end() ) {
- return (*trie).second.get();
+ return &(*trie).second;
} else {
- return (*trie).second->resolve(path, next);
+ return (*trie).second.resolve(path, next);
}
} else {
return nullptr;
}
}
- protected:
- std::map<Key, Trie::ptr> children_;
+ private:
+ std::map<Key, Trie> children_;
};
int main() {
Trie<uint8_t> test;
- test.add(std::forward_list<uint8_t>{1, 2, 3});
- test.add(std::forward_list<uint8_t>{1, 2, 4});
- test.add(std::forward_list<uint8_t>{2, 1});
- test.add(std::forward_list<uint8_t>{2, 1, 1});
+ test.add({1, 2, 3});
+ test.add({1, 2, 4});
+ test.add({2, 1});
+ test.add({2, 1, 1});
- assert(test.resolve(std::forward_list<uint8_t>{1, 2}) != nullptr);
- assert(test.resolve(std::forward_list<uint8_t>{1, 2, 4}) != nullptr);
- assert(test.resolve(std::forward_list<uint8_t>{3}) == nullptr);
+ assert(test.resolve({1, 2}) != nullptr);
+ assert(test.resolve({1, 2, 4}) != nullptr);
+ assert(test.resolve({3}) == nullptr);
}