#ifndef _map_h
#define _map_h
#include <cstdlib>
#include <functional>
#include <iostream>
#include <iterator>
#include <sstream>
#include <string>
#include "stack.h"
/*
 * Class: Map<KeyType,ValueType>
 * -----------------------------
 * Ordered associative container that maps keys to values.  Entries are
 * stored in a binary search tree whose nodes carry a balance factor (bf);
 * rotations performed after insertions and removals keep every balance
 * factor in the range [-1, +1].  Keys are ordered by a comparator object
 * owned by the map; two keys are treated as equal when neither compares
 * less than the other.
 */
template <typename KeyType, typename ValueType>
class Map {

public:

    /* Creates an empty map (defined out of line). */
    Map();

    /* Releases the comparator and every node of the tree. */
    virtual ~Map();

    /* Returns the number of entries. */
    int size() const;

    /* Returns true when the map holds no entries. */
    bool isEmpty() const;

    /* Stores value under key, replacing any value already stored there. */
    void put(const KeyType & key, const ValueType & value);

    /*
     * Returns a copy of the value stored under key, or a default-constructed
     * ValueType when the key is absent.
     */
    ValueType get(const KeyType & key) const;

    /* Returns true when an entry for key exists. */
    bool containsKey(const KeyType & key) const;

    /* Removes the entry for key, if there is one. */
    void remove(const KeyType & key);

    /* Removes every entry; the comparator is kept. */
    void clear();

    /*
     * Returns a reference to the value stored under key.  When the key is
     * absent a default-constructed value is inserted first.
     */
    ValueType & operator[](const KeyType & key);

    /* Read-only lookup; behaves like get(key). */
    ValueType operator[](const KeyType & key) const;

    /* Returns the text produced by streaming this map through operator<<. */
    std::string toString();

    /*
     * Calls fn(key, value) for every entry, visiting keys in the order
     * defined by the comparator (in-order traversal of the tree).
     */
    void mapAll(void (*fn)(KeyType, ValueType)) const;
    void mapAll(void (*fn)(const KeyType &, const ValueType &)) const;
    template <typename FunctorType>
    void mapAll(FunctorType fn) const;

private:

    /*
     * Balance-factor values.  bf is height(right subtree) minus
     * height(left subtree): growth on the left subtracts one, growth on the
     * right adds one.
     */
    static const int BST_LEFT_HEAVY = -1;
    static const int BST_IN_BALANCE = 0;
    static const int BST_RIGHT_HEAVY = +1;

    /* One tree node: a key/value pair, two children and a balance factor. */
    struct BSTNode {
        KeyType key;
        ValueType value;
        BSTNode *left;
        BSTNode *right;
        int bf;
    };

    /*
     * Type-erased interface to the key ordering, so that the comparator
     * type does not have to be a template parameter of Map itself.
     */
    class Comparator {
    public:
        virtual ~Comparator() { }
        virtual bool lessThan(const KeyType & k1, const KeyType & k2) = 0;
        virtual Comparator *clone() = 0;
    };

    /*
     * Comparator implementation that forwards to a heap-allocated copy of a
     * callable of type CompareType, invoked as cmp(k1, k2).
     */
    template <typename CompareType>
    class TemplateComparator : public Comparator {
    public:
        TemplateComparator(CompareType cmp) {
            this->cmp = new CompareType(cmp);
        }

        ~TemplateComparator() {
            delete cmp;
        }

        virtual bool lessThan(const KeyType & k1, const KeyType & k2) {
            return (*cmp)(k1, k2);
        }

        /* Returns a new comparator wrapping a copy of the same callable. */
        virtual Comparator *clone() {
            return new TemplateComparator<CompareType>(*cmp);
        }

    private:
        CompareType *cmp;
    };

    /* Returns the comparator object currently owned by this map. */
    Comparator & getComparator() const {
        return *cmpp;
    }

    BSTNode *root;        /* Root of the tree, NULL when the map is empty */
    int nodeCount;        /* Number of nodes reachable from root          */
    Comparator *cmpp;     /* Owned comparator used to order the keys      */

    /*
     * Searches the subtree rooted at t for key.  Returns a pointer to the
     * stored value, or NULL when the key is not present.
     */
    ValueType *findNode(BSTNode *t, const KeyType & key) const {
        if (t == NULL) return NULL;
        int sign = compareKeys(key, t->key);
        if (sign == 0) return &t->value;
        if (sign < 0) {
            return findNode(t->left, key);
        } else {
            return findNode(t->right, key);
        }
    }

    /*
     * Finds key in the subtree rooted at t, inserting a node holding a
     * default-constructed value when it is absent.  Returns a pointer to
     * the value slot.  heightFlag is set to true when the height of the
     * subtree grew as a result of the call, false otherwise.
     */
    ValueType *addNode(BSTNode * & t, const KeyType & key, bool & heightFlag) {
        heightFlag = false;
        if (t == NULL) {
            t = new BSTNode();
            t->key = key;
            t->value = ValueType();
            t->bf = BST_IN_BALANCE;
            t->left = t->right = NULL;
            heightFlag = true;
            nodeCount++;
            return &t->value;
        }
        int sign = compareKeys(key, t->key);
        if (sign == 0) return &t->value;
        ValueType *vp = NULL;
        int bfDelta = BST_IN_BALANCE;
        if (sign < 0) {
            vp = addNode(t->left, key, heightFlag);
            if (heightFlag) bfDelta = BST_LEFT_HEAVY;
        } else {
            vp = addNode(t->right, key, heightFlag);
            if (heightFlag) bfDelta = BST_RIGHT_HEAVY;
        }
        updateBF(t, bfDelta);
        // The subtree grew only if a child grew and this node did not end
        // up balanced (either by absorbing the growth or through a rotation).
        heightFlag = (bfDelta != 0 && t->bf != BST_IN_BALANCE);
        return vp;
    }

    /*
     * Removes key from the subtree rooted at t, if present.  Returns true
     * when the height of the subtree shrank.
     */
    bool removeNode(BSTNode * & t, const KeyType & key) {
        if (t == NULL) return false;
        int sign = compareKeys(key, t->key);
        if (sign == 0) return removeTargetNode(t);
        int bfDelta = BST_IN_BALANCE;
        if (sign < 0) {
            if (removeNode(t->left, key)) bfDelta = BST_RIGHT_HEAVY;
        } else {
            if (removeNode(t->right, key)) bfDelta = BST_LEFT_HEAVY;
        }
        updateBF(t, bfDelta);
        // The subtree shrank only if a child shrank and this node is now
        // balanced.
        return bfDelta != 0 && t->bf == BST_IN_BALANCE;
    }

    /*
     * Removes the node t itself.  Returns true when the height of the
     * subtree formerly rooted at that node shrank.  A node with at most one
     * child is spliced out directly.  A node with two children takes over
     * the key and value of the rightmost node of its left subtree, and that
     * node is then removed instead.
     */
    bool removeTargetNode(BSTNode * & t) {
        BSTNode *toDelete = t;
        if (t->left == NULL) {
            t = t->right;
            delete toDelete;
            nodeCount--;
            return true;
        } else if (t->right == NULL) {
            t = t->left;
            delete toDelete;
            nodeCount--;
            return true;
        } else {
            BSTNode *predecessor = t->left;
            while (predecessor->right != NULL) {
                predecessor = predecessor->right;
            }
            t->key = predecessor->key;
            t->value = predecessor->value;
            // Search with t->key (the copy just made) rather than with the
            // key stored in the node that the call below deletes, so the
            // reference passed down stays valid for the whole recursion.
            if (removeNode(t->left, t->key)) {
                updateBF(t, BST_RIGHT_HEAVY);
                return (t->bf == BST_IN_BALANCE);
            }
            return false;
        }
    }

    /*
     * Adds bfDelta to the balance factor of t and rotates when the result
     * falls outside [-1, +1].
     */
    void updateBF(BSTNode * & t, int bfDelta) {
        t->bf += bfDelta;
        if (t->bf < BST_LEFT_HEAVY) {
            fixLeftImbalance(t);
        } else if (t->bf > BST_RIGHT_HEAVY) {
            fixRightImbalance(t);
        }
    }

    /*
     * Restores balance at t when its left subtree is too tall.  A
     * right-heavy left child needs a double rotation; otherwise a single
     * right rotation is enough.  Balance factors of the nodes involved are
     * reassigned according to the pre-rotation state.
     */
    void fixLeftImbalance(BSTNode * & t) {
        BSTNode *child = t->left;
        if (child->bf == BST_RIGHT_HEAVY) {
            int oldBF = child->right->bf;
            rotateLeft(t->left);
            rotateRight(t);
            t->bf = BST_IN_BALANCE;
            switch (oldBF) {
            case BST_LEFT_HEAVY:
                t->left->bf = BST_IN_BALANCE;
                t->right->bf = BST_RIGHT_HEAVY;
                break;
            case BST_IN_BALANCE:
                t->left->bf = t->right->bf = BST_IN_BALANCE;
                break;
            case BST_RIGHT_HEAVY:
                t->left->bf = BST_LEFT_HEAVY;
                t->right->bf = BST_IN_BALANCE;
                break;
            }
        } else if (child->bf == BST_IN_BALANCE) {
            // Only reachable after a removal: the height is unchanged, so
            // the new root and its right child both stay off balance.
            rotateRight(t);
            t->bf = BST_RIGHT_HEAVY;
            t->right->bf = BST_LEFT_HEAVY;
        } else {
            rotateRight(t);
            t->right->bf = t->bf = BST_IN_BALANCE;
        }
    }

    /* Rotates left around t: its right child becomes the subtree root. */
    void rotateLeft(BSTNode * & t) {
        BSTNode *child = t->right;
        t->right = child->left;
        child->left = t;
        t = child;
    }

    /*
     * Mirror image of fixLeftImbalance: restores balance at t when its
     * right subtree is too tall.
     */
    void fixRightImbalance(BSTNode * & t) {
        BSTNode *child = t->right;
        if (child->bf == BST_LEFT_HEAVY) {
            int oldBF = child->left->bf;
            rotateRight(t->right);
            rotateLeft(t);
            t->bf = BST_IN_BALANCE;
            switch (oldBF) {
            case BST_LEFT_HEAVY:
                t->left->bf = BST_IN_BALANCE;
                t->right->bf = BST_RIGHT_HEAVY;
                break;
            case BST_IN_BALANCE:
                t->left->bf = t->right->bf = BST_IN_BALANCE;
                break;
            case BST_RIGHT_HEAVY:
                t->left->bf = BST_LEFT_HEAVY;
                t->right->bf = BST_IN_BALANCE;
                break;
            }
        } else if (child->bf == BST_IN_BALANCE) {
            // Only reachable after a removal; see fixLeftImbalance.
            rotateLeft(t);
            t->bf = BST_LEFT_HEAVY;
            t->left->bf = BST_RIGHT_HEAVY;
        } else {
            rotateLeft(t);
            t->left->bf = t->bf = BST_IN_BALANCE;
        }
    }

    /* Rotates right around t: its left child becomes the subtree root. */
    void rotateRight(BSTNode * & t) {
        BSTNode *child = t->left;
        t->left = child->right;
        child->right = t;
        t = child;
    }

    /* Deletes every node of the subtree rooted at t (post-order). */
    void deleteTree(BSTNode *t) {
        if (t != NULL) {
            deleteTree(t->left);
            deleteTree(t->right);
            delete t;
        }
    }

    /*
     * In-order traversal helpers behind the public mapAll overloads: visit
     * the left subtree, call fn on the node, then visit the right subtree.
     */
    void mapAll(BSTNode *t, void (*fn)(KeyType, ValueType)) const {
        if (t != NULL) {
            mapAll(t->left, fn);
            fn(t->key, t->value);
            mapAll(t->right, fn);
        }
    }

    void mapAll(BSTNode *t,
                void (*fn)(const KeyType &, const ValueType &)) const {
        if (t != NULL) {
            mapAll(t->left, fn);
            fn(t->key, t->value);
            mapAll(t->right, fn);
        }
    }

    template <typename FunctorType>
    void mapAll(BSTNode *t, FunctorType fn) const {
        if (t != NULL) {
            mapAll(t->left, fn);
            fn(t->key, t->value);
            mapAll(t->right, fn);
        }
    }

    /*
     * Overwrites root, nodeCount and cmpp with copies taken from other.
     * Whatever these members held before is not released here; callers
     * must have done so already.
     */
    void deepCopy(const Map & other) {
        root = copyTree(other.root);
        nodeCount = other.nodeCount;
        cmpp = other.cmpp->clone();
    }

    /* Returns a structural copy (keys, values, balance factors) of t. */
    BSTNode *copyTree(BSTNode * const t) {
        if (t == NULL) return NULL;
        BSTNode *np = new BSTNode();
        np->key = t->key;
        np->value = t->value;
        np->bf = t->bf;
        np->left = copyTree(t->left);
        np->right = copyTree(t->right);
        return np;
    }

public:

    /*
     * Creates an empty map ordered by cmp, which is invoked as cmp(k1, k2)
     * and must return true when k1 comes before k2.
     */
    template <typename CompareType>
    explicit Map(CompareType cmp) {
        root = NULL;
        nodeCount = 0;
        cmpp = new TemplateComparator<CompareType>(cmp);
    }

    /*
     * Three-way comparison built on the comparator: returns -1 when k1
     * orders before k2, +1 when k2 orders before k1, and 0 otherwise.
     */
    int compareKeys(const KeyType & k1, const KeyType & k2) const {
        if (cmpp->lessThan(k1, k2)) return -1;
        if (cmpp->lessThan(k2, k1)) return +1;
        return 0;
    }

    /*
     * Copy assignment.  The copy of src is built first and only then
     * exchanged with the current contents, so if copying throws this map
     * is left untouched instead of holding a deleted comparator.  The
     * previous tree and comparator are released when tmp is destroyed.
     */
    Map & operator=(const Map & src) {
        if (this != &src) {
            Map tmp(src);

            BSTNode *oldRoot = root;
            int oldCount = nodeCount;
            Comparator *oldCmp = cmpp;

            root = tmp.root;
            nodeCount = tmp.nodeCount;
            cmpp = tmp.cmpp;

            tmp.root = oldRoot;
            tmp.nodeCount = oldCount;
            tmp.cmpp = oldCmp;
        }
        return *this;
    }

    /* Copy constructor: deep-copies the tree and clones the comparator. */
    Map(const Map & src) {
        deepCopy(src);
    }

    /*
     * Class: iterator
     * ---------------
     * Input iterator over the keys of a map in comparator order.  The
     * iterator keeps an explicit stack of nodes on the path to the current
     * node; the node on top of the stack is the current one.  A marker
     * whose processed flag is set denotes an ancestor that has already been
     * visited and is on the stack only because its right subtree is being
     * traversed.  index counts the keys already passed and is what
     * equality comparison uses.
     */
    class iterator : public std::iterator<std::input_iterator_tag,KeyType> {

    private:

        struct NodeMarker {
            BSTNode *np;
            bool processed;
        };

        const Map *mp;
        int index;
        Stack<NodeMarker> stack;

        /*
         * Starting from the node on top of the stack, pushes the chain of
         * left children so that the leftmost descendant ends up on top.
         */
        void findLeftmostChild() {
            BSTNode *np = stack.peek().np;
            if (np == NULL) return;
            while (np->left != NULL) {
                NodeMarker marker = { np->left, false };
                stack.push(marker);
                np = np->left;
            }
        }

    public:

        /*
         * Creates an iterator attached to no map.  mp and index are given
         * definite values so that such iterators can be copied and compared
         * without reading uninitialised memory.
         */
        iterator() : mp(NULL), index(0) {
        }

        /*
         * Creates an iterator over *mp positioned at the first key, or at
         * the end position when end is true or the map is empty.
         */
        iterator(const Map *mp, bool end) {
            this->mp = mp;
            if (end || mp->nodeCount == 0) {
                index = mp->nodeCount;
            } else {
                index = 0;
                NodeMarker marker = { mp->root, false };
                stack.push(marker);
                findLeftmostChild();
            }
        }

        iterator(const iterator & it) {
            mp = it.mp;
            index = it.index;
            stack = it.stack;
        }

        /*
         * Advances to the next key.  The current node is popped; when it
         * has no right subtree, already-processed ancestors are discarded
         * so the next unvisited ancestor is on top.  Otherwise the node is
         * pushed back marked as processed and the leftmost node of its
         * right subtree becomes current.
         */
        iterator & operator++() {
            NodeMarker marker = stack.pop();
            BSTNode *np = marker.np;
            if (np->right == NULL) {
                while (!stack.isEmpty() && stack.peek().processed) {
                    stack.pop();
                }
            } else {
                marker.processed = true;
                stack.push(marker);
                marker.np = np->right;
                marker.processed = false;
                stack.push(marker);
                findLeftmostChild();
            }
            index++;
            return *this;
        }

        /* Post-increment: advances and returns the previous position. */
        iterator operator++(int) {
            iterator copy(*this);
            operator++();
            return copy;
        }

        /* Two iterators are equal when they share a map and a position. */
        bool operator==(const iterator & rhs) const {
            return mp == rhs.mp && index == rhs.index;
        }

        bool operator!=(const iterator & rhs) const {
            return !(*this == rhs);
        }

        /* Returns a copy of the current key. */
        KeyType operator*() {
            return stack.peek().np->key;
        }

        /* Returns the address of the current key inside its tree node. */
        KeyType *operator->() {
            return &stack.peek().np->key;
        }

        friend class Map;
    };

    /* Returns an iterator positioned at the first key. */
    iterator begin() const {
        return iterator(this, false);
    }

    /* Returns the past-the-end iterator. */
    iterator end() const {
        return iterator(this, true);
    }
};
/*
 * Constructor: Map
 * ----------------
 * Creates an empty map whose keys are ordered by std::less<KeyType>.  The
 * comparator is named with its std:: qualifier so that this header does
 * not depend on a using-directive for namespace std being in effect.
 */
template <typename KeyType, typename ValueType>
Map<KeyType,ValueType>::Map() {
    root = NULL;
    nodeCount = 0;
    cmpp = new TemplateComparator< std::less<KeyType> >(std::less<KeyType>());
}
/*
 * Destructor: ~Map
 * ----------------
 * Frees every node of the tree and the comparator object owned by the map.
 */
template <typename KeyType, typename ValueType>
Map<KeyType,ValueType>::~Map() {
    deleteTree(root);
    delete cmpp;
}
/*
 * Method: size
 * ------------
 * Reports how many entries the map currently holds.
 */
template <typename KeyType, typename ValueType>
int Map<KeyType,ValueType>::size() const {
    return this->nodeCount;
}
/*
 * Method: isEmpty
 * ---------------
 * Reports whether the map holds no entries at all.
 */
template <typename KeyType, typename ValueType>
bool Map<KeyType,ValueType>::isEmpty() const {
    return !(nodeCount != 0);
}
/*
 * Method: put
 * -----------
 * Locates the value slot for key (creating the entry when it is missing)
 * and assigns value to it.
 */
template <typename KeyType, typename ValueType>
void Map<KeyType,ValueType>::put(const KeyType & key,
                                 const ValueType & value) {
    bool heightChanged;   // required by addNode; not needed at the root
    ValueType *slot = addNode(root, key, heightChanged);
    *slot = value;
}
/*
 * Method: get
 * -----------
 * Returns a copy of the value associated with key.  A key that is not in
 * the map yields a default-constructed ValueType; the map is not modified.
 */
template <typename KeyType, typename ValueType>
ValueType Map<KeyType,ValueType>::get(const KeyType & key) const {
    ValueType *slot = findNode(root, key);
    if (slot != NULL) {
        return *slot;
    }
    return ValueType();
}
/*
 * Method: remove
 * --------------
 * Deletes the entry for key when one exists; otherwise does nothing.
 */
template <typename KeyType, typename ValueType>
void Map<KeyType,ValueType>::remove(const KeyType & key) {
    // The result only says whether the tree got shorter, which is of no
    // interest at the root.
    static_cast<void>(removeNode(root, key));
}
/*
 * Method: clear
 * -------------
 * Discards every entry, leaving an empty map with the same comparator.
 */
template <typename KeyType, typename ValueType>
void Map<KeyType,ValueType>::clear() {
    deleteTree(root);
    nodeCount = 0;
    root = NULL;
}
/*
 * Method: containsKey
 * -------------------
 * Reports whether the map has an entry for key.
 */
template <typename KeyType, typename ValueType>
bool Map<KeyType,ValueType>::containsKey(const KeyType & key) const {
    ValueType *slot = findNode(root, key);
    return !(slot == NULL);
}
/*
 * Operator: [] (mutable)
 * ----------------------
 * Gives read/write access to the value stored under key.  A missing key
 * is inserted with a default-constructed value before the reference is
 * returned.
 */
template <typename KeyType, typename ValueType>
ValueType & Map<KeyType,ValueType>::operator[](const KeyType & key) {
    bool heightChanged;   // required by addNode; not needed at the root
    ValueType *slot = addNode(root, key, heightChanged);
    return *slot;
}
/*
 * Operator: [] (const)
 * --------------------
 * Read-only lookup: returns a copy of the value stored under key, or a
 * default-constructed ValueType when the key is absent.  Never inserts.
 */
template <typename KeyType, typename ValueType>
ValueType Map<KeyType,ValueType>::operator[](const KeyType & key) const {
    ValueType *slot = findNode(root, key);
    if (slot == NULL) {
        return ValueType();
    }
    return *slot;
}
/*
 * Method: mapAll (by-value callback)
 * ----------------------------------
 * Invokes fn on every key/value pair, starting the in-order walk at the
 * root of the tree.
 */
template <typename KeyType, typename ValueType>
void Map<KeyType,ValueType>::mapAll(void (*fn)(KeyType, ValueType)) const {
    BSTNode *start = root;
    this->mapAll(start, fn);
}
/*
 * Method: mapAll (by-reference callback)
 * --------------------------------------
 * Invokes fn on every key/value pair, starting the in-order walk at the
 * root of the tree.  Keys and values are handed to fn by const reference.
 */
template <typename KeyType, typename ValueType>
void Map<KeyType,ValueType>::mapAll(void (*fn)(const KeyType &,
                                               const ValueType &)) const {
    BSTNode *start = root;
    this->mapAll(start, fn);
}
/*
 * Method: mapAll (functor)
 * ------------------------
 * Invokes the callable fn as fn(key, value) on every entry, starting the
 * in-order walk at the root of the tree.
 */
template <typename KeyType, typename ValueType>
template <typename FunctorType>
void Map<KeyType,ValueType>::mapAll(FunctorType fn) const {
    BSTNode *start = root;
    this->mapAll(start, fn);
}
/*
 * Method: toString
 * ----------------
 * Returns the text obtained by writing this map to a string stream with
 * operator<<.  The stream type is named with its std:: qualifier so that
 * this header does not depend on a using-directive for namespace std.
 */
template <typename KeyType, typename ValueType>
std::string Map<KeyType,ValueType>::toString() {
    std::ostringstream os;
    os << *this;
    return os.str();
}
template <typename KeyType, typename ValueType>
std::ostream & operator<<(std::ostream & os,
const Map<KeyType,ValueType> & map) {
os << "{";
typename Map<KeyType,ValueType>::iterator begin = map.begin();
typename Map<KeyType,ValueType>::iterator end = map.end();
typename Map<KeyType,ValueType>::iterator it = begin;
while (it != end) {
if (it != begin) os << ", ";
writeGenericValue(os, *it, false);
os << ":";
writeGenericValue(os, map[*it], false);
++it;
}
return os << "}";
}
/*
 * Operator: >>
 * ------------
 * Replaces the contents of map with entries parsed from is.  The expected
 * form is "{key:value, key:value, ...}", with "{}" denoting an empty map;
 * keys and values are parsed by readGenericValue.  Malformed input is
 * reported through error().
 *
 * ch starts out as '\0' because a failed character extraction leaves its
 * target unchanged; without the initialiser the very first comparison
 * could read an indeterminate value when the stream is empty or already
 * in a failed state.
 */
template <typename KeyType, typename ValueType>
std::istream & operator>>(std::istream & is, Map<KeyType,ValueType> & map) {
    char ch = '\0';
    is >> ch;
    if (ch != '{') error("operator >>: Missing {");
    map.clear();
    is >> ch;
    if (ch != '}') {
        // Not an empty map: give the character back so that it is parsed
        // as the start of the first key.
        is.unget();
        while (true) {
            KeyType key;
            readGenericValue(is, key);
            is >> ch;
            if (ch != ':') error("operator >>: Missing colon after key");
            ValueType value;
            readGenericValue(is, value);
            map[key] = value;
            is >> ch;
            if (ch == '}') break;
            if (ch != ',') {
                error(std::string("operator >>: Unexpected character ") + ch);
            }
        }
    }
    return is;
}
#endif