#ifndef AVLTree_h
#define AVLTree_h

#include "AllocBuf.h"
#include "ext_compare.h"

template <class T>
class AVLTree;

#pragma interface

/*
 * AVLItem<T> - one node of an AVLTree<T>.
 *
 * The node *is* a T (public inheritance), extended by the two child links and
 * the AVL balance factor used by AVLTree<T>. Balance convention used by
 * AVLTree: balance = height(right subtree) - height(left subtree), i.e.
 * -1 left-heavy, 0 balanced, +1 right-heavy.
 *
 * NOTE(review): LAZYCLASS / LAZYOPS are project macros defined in a header not
 * shown here (presumably AllocBuf.h) - TODO confirm what they expand to.
 */
template <class T>
class AVLItem : public T {

LAZYCLASS

protected:

	AVLItem * left;   // left child, or null
	AVLItem * right;  // right child, or null
	int balance;      // height(right) - height(left)

public:
	
	// Outside DEBUG builds left/right/balance are left uninitialised here;
	// AVLTree<T>::insert(AVLItem<T> *) resets them before linking the node.
	AVLItem (void) : T()
#ifdef DEBUG
	                    , left((AVLItem *)0), right((AVLItem *)0), balance(0)
#endif /* DEBUG */
                                                                             { };

	// Copies "val" into the T part; links are initialised only in DEBUG
	// builds (see above).
	AVLItem (const T & val) : T(val)
#ifdef DEBUG
	         , left((AVLItem *)0), right((AVLItem *)0), balance(0)
#endif /* DEBUG */
	        { };

	// Read-only access for code outside AVLTree (e.g. AVL_fast_find).
	AVLItem * get_left(void)  const { return left; }
	AVLItem * get_right(void) const { return right; }
	int get_balance(void) const { return balance; }

	// Only the tree may relink nodes and update balance factors.
	friend class AVLTree<T>;

};

LAZYOPS(template<class T>,AVLItem<T>)


template <class T>
class AVLTree {

protected:
	
	AVLItem<T> * top;
	long items;
	
	void recurse_clear(AVLItem<T> * item);
	int recurse_insert(AVLItem<T> * & aktnode, AVLItem<T> * item, int cmpresult);
	int recurse_remove(AVLItem<T> * & aktnode, AVLItem<T> * item,
                      int search_next); 
	int repair_right (AVLItem<T> * & aktnode); 
	int repair_left  (AVLItem<T> * & aktnode); 
#ifdef AVL_DEBUG
	int recurse_check(AVLItem<T> * aktnode, int & cnt);                             
#endif /* AVL_DEBUG */
	
public:

	AVLTree (void) : top((AVLItem<T> *)0), items(0) { };

	void clear(void);

	~AVLTree (void) {
		clear();
	}

#ifdef AVL_DEBUG
	int check (void) const;
#endif /* AVL_DEBUG */

	AVLItem<T> * find(const T & val) const;
	AVLItem<T> * find_first(const T & val) const;

	void insert(AVLItem<T> * item);
	
	int insert(const T & val) {
		AVLItem<T> * item = new AVLItem<T> (val);
		if (! item) return -1;
		insert(item);
		return 0;
	}
	
	void remove(AVLItem<T> * item);
	int remove(const T & val);	
	
	AVLItem<T> * get_head(void) const;
	AVLItem<T> * get_tail(void) const;
	AVLItem<T> * get_prev( AVLItem<T> * item ) const;
	AVLItem<T> * get_next( AVLItem<T> * item ) const ;
	
	long size(void) const { return items; }
	long length(void) const { return items; }
	AVLItem<T> * get_top(void)   const { return top; }
	
};


#ifdef AVL_DEBUG
#include <iostream.h>

template <class T>
int AVLTree<T>::recurse_check( AVLItem<T> * aktnode, int & cnt) {
	// Debug helper: validates the subtree rooted at aktnode (non-null),
	// counting every visited node into "cnt".
	// Returns the subtree height (>= 1), or 0 if an AVL invariant is broken.
	++cnt;

	// Height of the left subtree; 0 means "empty", while a failing non-empty
	// child is reported here and propagated as 0.
	int hl = 0;
	if (aktnode->left) {
		hl = recurse_check(aktnode->left, cnt);
		if (!hl) {
			cerr << "failure at node " << *((T *)aktnode->left) << '\n';
			return 0;
		}
	}

	int hr = 0;
	if (aktnode->right) {
		hr = recurse_check(aktnode->right, cnt);
		if (!hr) {
			cerr << "failure at node " << *((T *)aktnode->right) << '\n';
			return 0;
		}
	}

	// The two subtree heights may differ by at most one.
	if (hl - hr > 1 || hr - hl > 1) return 0;

	// The stored balance factor must equal sign(height(right) - height(left)).
	const int expected = (hl < hr) ? 1 : ((hl > hr) ? -1 : 0);
	if (aktnode->balance != expected) return 0;

	return 1 + ((hl > hr) ? hl : hr);
}

/*
 * AVL_DEBUG self-test of the whole tree.
 *
 * Returns 0 when the tree is consistent: every node satisfies the AVL height
 * and balance-factor invariants (see recurse_check), and the number of
 * reachable nodes equals "items". Returns -1 otherwise.
 *
 * The definition is const to match the in-class declaration "check(void) const";
 * without the qualifier it does not compile.
 */
template <class T>
int AVLTree<T>::check(void) const {

	if (! top) {
		// An empty tree is consistent only if the counter agrees.
		return (items == 0) ? 0 : -1;
	}

	int cnt = 0;

	// recurse_check is declared non-const but never modifies the tree
	// (it only reads links/balances and writes "cnt"), so calling it through
	// a const_cast'ed "this" is safe.
	if (0 == const_cast<AVLTree<T> *>(this)->recurse_check ( top, cnt )) return -1;
	if (cnt != items) return -1;
	return 0;
}

#endif /* AVL_DEBUG */

template <class T>
void AVLTree<T>::remove(AVLItem<T> * item) {
	// Unlinks "item" (which must be in this tree) and rebalances; the node is
	// not deleted. Whether the root's height changed is irrelevant here.
	(void) recurse_remove( top, item, 0 );
	--items;
}

template <class T> 
int AVLTree<T>::remove(const T & val) {
	// Unlinks and deletes every node whose value compares equal to "val".
	// Returns how many nodes were removed (0 if none matched).
	AVLItem<T> * match = find(val);
	if (! match) return 0;

	int removed = 0;

	// Equal values are contiguous in order: first drop all equal successors
	// of the match.
	AVLItem<T> * cur = get_next(match);
	while (cur && cur->compare(val) == 0) {
		AVLItem<T> * following = get_next(cur);
		remove(cur);
		delete cur;
		++removed;
		cur = following;
	}

	// ...then all equal predecessors...
	cur = get_prev(match);
	while (cur && cur->compare(val) == 0) {
		AVLItem<T> * preceding = get_prev(cur);
		remove(cur);
		delete cur;
		++removed;
		cur = preceding;
	}

	// ...and finally the match itself.
	remove(match);
	delete match;

	return removed + 1;
}

/*
 * Restores the AVL balance of "aktnode" after its LEFT subtree has become one
 * level shorter (called from recurse_remove).
 *
 * Balance convention: balance = height(right) - height(left), so a shorter
 * left subtree moves the balance one step towards the right.
 *
 * aktnode - reference to the parent's link (or AVLTree::top). When a rotation
 *           is needed, it is overwritten with the new subtree root.
 *
 * Returns -1 if the subtree rooted at aktnode became one level shorter (the
 * caller must keep repairing upwards), 0 if its height is unchanged.
 */
template <class T>
int AVLTree<T>::repair_left (AVLItem<T> * & aktnode) {

	if (aktnode->balance == 0) {
		// was balanced: now right-heavy, overall height unchanged
		aktnode->balance = 1;
		return 0;
	} else if (aktnode->balance < 0) {
		// was left-heavy: now balanced, overall height shrank by one
		aktnode->balance = 0;
		return -1;
	} else {
		// aktnode->balance > 0
		// was right-heavy: now two levels out of balance, so rotate
		
		AVLItem<T> * right = aktnode->right;
		
		if (right->balance < 0) {
			// Right child is left-heavy: double (right-left) rotation that
			// lifts e = right->left to the top of this subtree:
			//
			//        a                  e
			//       / \               /   \
			//      d   c      =>     a     c
			//         / \           / \   / \
			//        e   f         d   x y   f
			//       / \
			//      x   y
			
			AVLItem<T> * a = aktnode;
			AVLItem<T> * d = aktnode->left;
			AVLItem<T> * c = aktnode->right;
			AVLItem<T> * e = c->left;
			AVLItem<T> * f = c->right;
			AVLItem<T> * x = e->left;
			AVLItem<T> * y = e->right;
			
			aktnode = e;
			aktnode->left = a;
			aktnode->right = c;
			aktnode->left->left = d;   // unchanged link, set for clarity
			aktnode->left->right = x;
			aktnode->right->left = y;
			aktnode->right->right = f; // unchanged link, set for clarity
			
			// The new balances depend on which of e's subtrees was the taller:
			// e right-heavy => x was shorter => a becomes left-heavy;
			// e left-heavy  => y was shorter => c becomes right-heavy.
			int balancebuf = e->balance;
			aktnode->balance = 0;
			aktnode->left->balance = 0;
			aktnode->right->balance = 0;
			if (balancebuf > 0) {
				aktnode->left->balance = -1;
			} else if (balancebuf < 0) {
				aktnode->right->balance = 1;
			}
			
			// a double rotation always reduces the subtree height by one
			return -1;
						
		} else {
			// right->balance >= 0
			// single left rotation: "right" becomes the subtree root
			
			int result = 0;
			
			if (right->balance == 0) {
				// height unchanged; the new root leans left
				right->balance = -1;
				aktnode->balance = 1;
			} else {
				// right->balance > 0
				// both become balanced and the subtree gets one level shorter
				right->balance = 0;
				aktnode->balance = 0;
				result = -1;
			}

			aktnode->right = right->left;
			right->left = aktnode;
			aktnode = right;
			
			return result;
		}
	} 
}

/*
 * Restores the AVL balance of "aktnode" after its RIGHT subtree has become one
 * level shorter (called from recurse_remove). Mirror image of repair_left.
 *
 * Balance convention: balance = height(right) - height(left), so a shorter
 * right subtree moves the balance one step towards the left.
 *
 * aktnode - reference to the parent's link (or AVLTree::top). When a rotation
 *           is needed, it is overwritten with the new subtree root.
 *
 * Returns -1 if the subtree rooted at aktnode became one level shorter (the
 * caller must keep repairing upwards), 0 if its height is unchanged.
 */
template <class T>
int AVLTree<T>::repair_right (AVLItem<T> * & aktnode) {

	if (aktnode->balance == 0) {
		// was balanced: now left-heavy, overall height unchanged
		aktnode->balance = -1;
		return 0;
	} else if (aktnode->balance > 0) {
		// was right-heavy: now balanced, overall height shrank by one
		aktnode->balance = 0;
		return -1;
	} else {
		// aktnode->balance < 0
		// was left-heavy: now two levels out of balance, so rotate
		
		AVLItem<T> * left = aktnode->left;
		
		if (left->balance > 0) {
			// Left child is right-heavy: double (left-right) rotation that
			// lifts e = left->right to the top of this subtree:
			//
			//        a                  e
			//       / \               /   \
			//      c   d      =>     c     a
			//     / \               / \   / \
			//    f   e             f   y x   d
			//       / \
			//      y   x
			
			AVLItem<T> * a = aktnode;
			AVLItem<T> * d = aktnode->right;
			AVLItem<T> * c = aktnode->left;
			AVLItem<T> * e = c->right;
			AVLItem<T> * f = c->left;
			AVLItem<T> * x = e->right;
			AVLItem<T> * y = e->left;
			
			aktnode = e;
			aktnode->right = a;
			aktnode->left = c;
			aktnode->right->right = d; // unchanged link, set for clarity
			aktnode->right->left = x;
			aktnode->left->right = y;
			aktnode->left->left = f;   // unchanged link, set for clarity
			
			// e left-heavy  => x was shorter => a becomes right-heavy;
			// e right-heavy => y was shorter => c becomes left-heavy.
			int balancebuf = e->balance;
			aktnode->balance = 0;
			aktnode->left->balance = 0;
			aktnode->right->balance = 0;
			if (balancebuf < 0) {
				aktnode->right->balance = 1;
			} else if (balancebuf > 0) {
				aktnode->left->balance = -1;
			}
			
			// a double rotation always reduces the subtree height by one
			return -1;
						
		} else {
			// left->balance <= 0
			// single right rotation: "left" becomes the subtree root
			
			int result = 0;
			
			if (left->balance == 0) {
				// height unchanged; the new root leans right
				left->balance = 1;
				aktnode->balance = -1;
			} else {
				// left->balance < 0
				// both become balanced and the subtree gets one level shorter
				left->balance = 0;
				aktnode->balance = 0;
				result = -1;
			}

			aktnode->left = left->right;
			left->right = aktnode;
			aktnode = left;
			
			return result;
		}
	} 
}

/*
 * Unlinks "item" from the subtree whose root link is "aktnode" and rebalances
 * on the way back up. The node itself is not deleted.
 *
 * search_next == 0: normal mode. Descend from aktnode (left when
 *                   ext_compare(item, node) < 0, otherwise right) until the
 *                   node pointer equal to "item" is reached. There is no null
 *                   check on the way, so "item" must be linked into this
 *                   subtree.
 * search_next != 0: helper mode used when the removed node has two children.
 *                   Unlink the leftmost node of the subtree (the in-order
 *                   successor of the removed node), replacing it with its
 *                   right child, and leave it in the static successor_buf.
 *
 * aktnode - reference to the parent's link (or AVLTree::top); it is rewritten
 *           when the node it points to is unlinked or rotated away.
 *
 * Returns non-zero (-1) if the height of the subtree rooted at aktnode
 * decreased, 0 otherwise.
 *
 * NOTE(review): successor_buf is a function-local static shared by every
 * AVLTree<T> of the same T. This function is therefore not reentrant and not
 * safe to run concurrently, even on different trees of the same type.
 * NOTE(review): the search direction presumably relies on ext_compare ordering
 * nodes consistently with how insert placed them (relevant for equal values)
 * - TODO confirm against ext_compare.h.
 */
template <class T>
int AVLTree<T>::recurse_remove(AVLItem<T> * & aktnode, AVLItem<T> * item,
                               int search_next) {

	static AVLItem<T> * successor_buf;

	if (search_next) {
		// helper mode: unlink the leftmost node of this subtree
		
		if (aktnode->left) {
			if ( recurse_remove ( aktnode->left, item, -1) ) {
				// left tree has been shortened - repair
				return repair_left(aktnode);
			}	
			return 0;
		} else {
			// aktnode is the leftmost node: remember it and splice in its
			// right child (possibly null); this subtree got one level shorter
			
			successor_buf = aktnode;
			aktnode = aktnode->right;
		
			return -1;
		}
	
	} else {
		if (aktnode == item) {
			// item found, remove it...
			
			if ( aktnode->right && aktnode->left ) {
				// aktnode has two children: replace it by its in-order
				// successor, which is first unlinked from the right subtree
				
				int need_repair = recurse_remove( aktnode->right, item, -1 );
				
				// The successor takes item's place. item->right is read only
				// now because the call above may have changed it (when the
				// successor was item->right itself).
				aktnode = successor_buf;
				successor_buf->right = item->right;
				successor_buf->left = item->left;
				successor_buf->balance = item->balance;
				
				// item is now no longer part of the tree and may be deleted

				if (need_repair) {
					// right tree has been shortened - repair
					return repair_right(aktnode);
				} else {
					return 0;
				}
				
			} else {
				// aktnode has no or one child: splice that child (or null) in
				
				if (aktnode->left) {
					aktnode = aktnode->left;
				} else {
					aktnode = aktnode->right;
				}
				
				// item is now no longer part of the tree and may be deleted
				return -1;
			}
		
		} else {		
			// still searching for "item"
			
			if (ext_compare((const T *)item, (const T *)aktnode) < 0) {
				if ( recurse_remove(aktnode->left, item, 0) ) {
					// left tree has been shortened - repair
					return repair_left(aktnode);
				}
				return 0;
			} else {
				if ( recurse_remove(aktnode->right, item, 0) ) {
					// right tree has been shortened - repair
					return repair_right(aktnode);
				}
				return 0;
			}
		}
	}			
}

template <class T>
AVLItem<T> * AVLTree<T>::get_head(void) const {
	// The in-order first node is the leftmost one; null for an empty tree.
	AVLItem<T> * node = top;
	while (node && node->left) {
		node = node->left;
	}
	return node;
}

template <class T>
AVLItem<T> * AVLTree<T>::get_tail(void) const {
	// The in-order last node is the rightmost one; null for an empty tree.
	AVLItem<T> * node = top;
	while (node && node->right) {
		node = node->right;
	}
	return node;
}

template <class T>
AVLItem<T> * AVLTree<T>::get_prev( AVLItem<T> * item ) const {
	// In-order predecessor of "item" (which must be in the tree), or null.

	// With a left subtree, the predecessor is that subtree's rightmost node.
	if (item->left) {
		AVLItem<T> * node = item->left;
		while (node->right) node = node->right;
		return node;
	}

	// Otherwise walk down from the root towards item; the predecessor is the
	// last node at which the path turned right.
	AVLItem<T> * candidate = (AVLItem<T> *)0;
	AVLItem<T> * node = top;
	while (node != item) {
		if (ext_compare((const T *)item, (const T *)node) < 0) {
			node = node->left;
		} else {
			candidate = node;
			node = node->right;
		}
	}
	return candidate;
}

template <class T>
AVLItem<T> * AVLTree<T>::get_next( AVLItem<T> * item ) const {
	// In-order successor of "item" (which must be in the tree), or null.

	// With a right subtree, the successor is that subtree's leftmost node.
	if (item->right) {
		AVLItem<T> * node = item->right;
		while (node->left) node = node->left;
		return node;
	}

	// Otherwise walk down from the root towards item; the successor is the
	// last node at which the path turned left.
	AVLItem<T> * candidate = (AVLItem<T> *)0;
	AVLItem<T> * node = top;
	while (node != item) {
		if (ext_compare((const T *)item, (const T *)node) < 0) {
			candidate = node;
			node = node->left;
		} else {
			node = node->right;
		}
	}
	return candidate;
}

template <class T>
void AVLTree<T>::insert(AVLItem<T> * item) {
	// Links "item" into the tree (the tree takes ownership) and rebalances.

	// Whatever links the node carried before are discarded: it enters as a leaf.
	item->balance = 0;
	item->left = item->right = (AVLItem<T> *) 0;

	if (items != 0) {
		const int dir = ext_compare((const T *)item, (const T *)top);
		recurse_insert(top, item, dir);
	} else {
		top = item;
	}

	++items;
}


/*
 * Inserts "item" into the subtree whose root link is "aktnode".
 *
 * This is a single-pass, top-down AVL insertion. Whenever the descent would
 * enter the already-taller side of a node, the node is restructured *before*
 * descending, so the insertion below can never make that side taller. As a
 * result, a call only reports growth when the node it was called on had
 * balance 0.
 *
 * Balance convention: balance = height(right) - height(left).
 * Ordering convention: ext_compare(item, node) < 0 descends left; anything
 * else, including 0 (equal values), descends right.
 *
 * aktnode   - reference to the parent's link (or AVLTree::top); rotations
 *             overwrite it with the new subtree root.
 * item      - node to insert; its left/right/balance must already be cleared
 *             (AVLTree::insert does this).
 * cmpresult - ext_compare(item, aktnode), precomputed by the caller.
 *
 * Returns -1 if the height of the subtree rooted at aktnode grew by one,
 * 0 otherwise.
 */
template <class T>
int AVLTree<T>::recurse_insert(AVLItem<T> * & aktnode, AVLItem<T> * item,
                               int cmpresult) {

	if (cmpresult < 0) {
		// ----- insert left from aktnode -----

		if (aktnode->balance >= 0) {
			// The left side is not the taller one, so it may grow by one level.

			if (! aktnode->left) {
				// insert now to the left of aktnode
				aktnode->left = item;
				aktnode->balance--;
				if (aktnode->balance < 0) return -1; // length of subtree increased
				return 0;
			}

			if (recurse_insert( aktnode->left, item,
			                    ext_compare((const T *)item, (const T *)aktnode->left)) ) {
				aktnode->balance--;
				if (aktnode->balance < 0) return -1;
				return 0;
			} else {
				return 0;
			}
		} else {
			// insert needs to be performed left from aktnode, but balance is
			// already "left-heavy": the left subtree must not grow, so
			// restructure (if necessary) before descending.

			if (aktnode->left->balance != 0) {
				// we can proceed to the left node: a node with non-zero balance
				// never reports growth from this function.

#ifdef DEBUG
				if (recurse_insert( aktnode->left, item,
				                    ext_compare((const T *)item, (const T *)aktnode->left)) ) {
					// invariant violated - crash deliberately
					*((unsigned long *) ~0) = 0x44;
				}
#else
				recurse_insert( aktnode->left, item,
				                ext_compare((const T *)item, (const T *)aktnode->left));
#endif /* DEBUG */
				return 0;
			}

			int nextcmp = ext_compare((const T *)item, (const T *)aktnode->left);

			if (nextcmp < 0) {
				// item goes into left->left: rotate right first. Afterwards the
				// old aktnode is balanced and "left" (the new root) is
				// right-heavy, so its left side may grow without changing the
				// height of this subtree.
				AVLItem<T> * left = aktnode->left;

				aktnode->left = left->right;
				left->right = aktnode;

				aktnode->balance++;
				left->balance++;

				aktnode = left;

				recurse_insert( aktnode, item, nextcmp );
				return 0;

			} else {
				// item goes into left->right

				if (! aktnode->left->right) {
					// left is a leaf (balance 0, no right child): rearrange the
					// three nodes with "item" as the new aktnode, the old left
					// as its left child and the old aktnode as its right child.

					item->right = aktnode;
					item->left = aktnode->left;
					aktnode->left = (AVLItem<T> *)0;
					aktnode = item;

					aktnode->balance = 0;
					aktnode->right->balance++;

					return 0;

				} else {
					AVLItem<T> * nextakt = aktnode->left->right;

					if (nextakt->balance != 0) {
						// nextakt cannot grow, hence neither can left.
#ifdef DEBUG
						if (recurse_insert( aktnode->left, item, nextcmp) ) {
							*((unsigned long *) ~0) = 0x44;
						}
#else
						recurse_insert( aktnode->left, item, nextcmp);
#endif /* DEBUG */
						return 0;
					} else {
						// left->right will become aktnode (double rotation).
						// The subtree height is unchanged, and the following
						// insertion can no longer increase it.

						aktnode->left->right = nextakt->left;
						nextakt->left = aktnode->left;
						aktnode->left = nextakt->right;
						nextakt->right = aktnode;

						nextakt->left->balance = -1;
						nextakt->balance = 0;
						aktnode->balance = 1;

						aktnode = nextakt;

						recurse_insert( aktnode, item,
						                ext_compare((const T *)item, (const T *)aktnode));
						return 0;
					}

				}
			}
		}

	} else {
		// ----- insert right from aktnode (mirror image of the above) -----

		if (aktnode->balance <= 0) {
			// The right side is not the taller one, so it may grow by one level.

			if (! aktnode->right) {
				// insert now to the right of aktnode
				aktnode->right = item;
				aktnode->balance++;
				if (aktnode->balance > 0) return -1; // length of subtree increased
				return 0;
			}

			if (recurse_insert( aktnode->right, item,
			                    ext_compare((const T *)item, (const T *)aktnode->right)) ) {
				aktnode->balance++;
				if (aktnode->balance > 0) return -1;
				return 0;
			} else {
				return 0;
			}
		} else {
			// insert needs to be performed right from aktnode, but balance is
			// already "right-heavy": the right subtree must not grow.

			if (aktnode->right->balance != 0) {
				// we can proceed to the right node: a node with non-zero balance
				// never reports growth from this function.

#ifdef DEBUG
				if (recurse_insert( aktnode->right, item,
				                    ext_compare((const T *)item, (const T *)aktnode->right)) ) {
					// invariant violated - crash deliberately
					*((unsigned long *) ~0) = 0x44;
				}
#else
				recurse_insert( aktnode->right, item,
				                ext_compare((const T *)item, (const T *)aktnode->right));
#endif /* DEBUG */
				return 0;
			}

			int nextcmp = ext_compare((const T *)item, (const T *)aktnode->right);

			// item descends into right->right whenever nextcmp >= 0: equal
			// values go right, exactly as in the recursive calls below (and as
			// "nextcmp < 0" mirrors it on the left side). Testing "> 0" sent
			// nextcmp == 0 into the right->left branch, whose recursive call
			// then actually inserted into right->right of a balanced node and
			// could grow an already right-heavy subtree without rebalancing.
			if (nextcmp >= 0) {
				// item goes into right->right: rotate left first.
				AVLItem<T> * right = aktnode->right;

				aktnode->right = right->left;
				right->left = aktnode;

				aktnode->balance--;
				right->balance--;

				aktnode = right;

				recurse_insert( aktnode, item, nextcmp );
				return 0;

			} else {
				// item goes into right->left

				if (! aktnode->right->left) {
					// right is a leaf: "item" will become aktnode, with the old
					// aktnode as its left child and the old right as its right.

					item->left = aktnode;
					item->right = aktnode->right;
					aktnode->right = (AVLItem<T> *)0;
					aktnode = item;

					aktnode->balance = 0;
					aktnode->left->balance--;

					return 0;

				} else {
					AVLItem<T> * nextakt = aktnode->right->left;
					if (nextakt->balance != 0) {
						// nextakt cannot grow, hence neither can right.
#ifdef DEBUG
						if (recurse_insert( aktnode->right, item, nextcmp) ) {
							*((unsigned long *) ~0) = 0x44;
						}
#else
						recurse_insert( aktnode->right, item, nextcmp);
#endif /* DEBUG */
						return 0;
					} else {
						// right->left will become aktnode (double rotation).

						aktnode->right->left = nextakt->right;
						nextakt->right = aktnode->right;
						aktnode->right = nextakt->left;
						nextakt->left = aktnode;

						nextakt->right->balance = 1;
						nextakt->balance = 0;
						aktnode->balance = -1;

						aktnode = nextakt;

						recurse_insert( aktnode, item,
						                ext_compare((const T *)item, (const T *)aktnode));
						return 0;
					}
				}
			}
		}
	}
}

template <class T>
AVLItem<T> * AVLTree<T>::find(const T & val) const {
	// Binary search by T::compare; returns some node comparing equal to
	// "val", or null if there is none.
	AVLItem<T> * node = top;
	while (node) {
		const int cmp = val.compare(*node);
		if (cmp == 0) return node;
		node = (cmp < 0) ? node->left : node->right;
	}
	return 0;
}

template <class T>
AVLItem<T> * AVLTree<T>::find_first(const T & val) const {
	// Locates any node equal to "val", then walks backwards in order to the
	// first of the run of equal values. Returns null if nothing matches.
	AVLItem<T> * first = find(val);
	if (first) {
		AVLItem<T> * before;
		while ((before = get_prev(first)) != 0 && before->compare(val) == 0) {
			first = before;
		}
	}
	return first;
}

template <class T>
void AVLTree<T>::recurse_clear( AVLItem<T> * item ) {
	// Post-order deletion of the (non-null) subtree rooted at "item".
	AVLItem<T> * lhs = item->left;
	AVLItem<T> * rhs = item->right;

	if (lhs) recurse_clear( lhs );
	if (rhs) recurse_clear( rhs );

	delete item;
}

template <class T>
void AVLTree<T>::clear(void) {
	// Deletes every node and resets the tree to the empty state.
	if (top) recurse_clear(top);
	top = (AVLItem<T> *) 0;
	items = 0;
}

//--------------------------------------------------------
// non-member functions
//--------------------------------------------------------

template <class T, class U>
AVLItem<T> * AVL_fast_find(const AVLTree<T> & tree, const U & val) {
	// Binary search using a free compare(U, T) function, so the key may be a
	// type other than T. Returns a node comparing equal to "val", or null.
	for (AVLItem<T> * node = tree.get_top(); node; ) {
		const int cmp = compare(val, (const T &)(*node));
		if (!cmp) return node;
		node = (cmp < 0) ? node->get_left() : node->get_right();
	}
	return 0;
}

//--------------------------------------------------------
// I/O routines (non-member functions)
//--------------------------------------------------------

#ifdef DEBUG

#include "Str.h"

template <class T>
void recurse_AVLTout ( ostream & out, AVLItem<T> * item,
                              unsigned long & indent ) {
	// Debug dump of the non-null subtree at "item": the node, then its left
	// subtree, then its right subtree, indented by recursion depth. Each node
	// is followed by a balance marker ('/' left-heavy, '\' right-heavy,
	// '-' balanced); a missing child is printed as '#'.
	// NOTE(review): Str(n).fill() is assumed to yield n padding characters
	// - TODO confirm against Str.h.
	++indent;

	Str pad(indent);
	pad.fill();
	Str child_pad(indent+1);
	child_pad.fill();

	const int bal = item->get_balance();
	char marker;
	if (bal < 0)      marker = '/';
	else if (bal > 0) marker = '\\';
	else              marker = '-';

	out << pad << *((T *)item) << " " << marker << '\n';

	AVLItem<T> * kid = item->get_left();
	if (kid) recurse_AVLTout (out, kid, indent);
	else     out << child_pad << "#\n";

	kid = item->get_right();
	if (kid) recurse_AVLTout (out, kid, indent);
	else     out << child_pad << "#\n";

	--indent;
}

/*
 * Debug dump of a whole tree, framed by a header and a footer line.
 *
 * An empty tree is printed as just the header and the footer. The root used
 * to be passed to recurse_AVLTout unconditionally, and that helper
 * dereferences its node, so dumping an empty tree crashed.
 */
template <class T>
ostream & operator<< (ostream & out, const AVLTree<T> & val) {

	unsigned long indent = 0;

	out << "AVLTree:\n";

	AVLItem<T> * root = val.get_top();
	if (root) {
		recurse_AVLTout ( out, root, indent);
	}

	out << "========\n";

	return out;
}

#endif /* DEBUG */

#include "Bstream.h"

template <class T>
BOstream & operator<< (BOstream & out, const AVLTree<T> & val) {
	// Binary serialisation: the element count, then every element in order.
	out << val.size();
	for (AVLItem<T> * node = val.get_head(); node; node = val.get_next(node)) {
		out << *static_cast<T *>(node);
	}
	return out;
}

/*
 * Binary deserialisation, the counterpart of the BOstream operator<< above:
 * reads an element count, then that many T values. Each value is inserted
 * into "val" as a new node. Existing contents of "val" are kept; the loaded
 * elements are added to them.
 *
 * The count starts at 0 so that a failed read inserts nothing instead of
 * using an uninitialised value. A negative count (corrupt input) is treated
 * as 0; with the old "while (i--)" it looped practically forever.
 */
template <class T>
BIstream & operator>> (BIstream & in, AVLTree<T> & val) {

	long i = 0;
	in >> i;

	for (long k = 0; k < i; ++k) {
		AVLItem<T> * t = new AVLItem<T>;
		in >> ((T &)(*t));
		val.insert(t);
	}

	return in;
}


#endif /* AVLTree_h */


// syntax highlighted by Code2HTML, v. 0.9.1