#include #include #include #include #include "strbst.h" #define max(a,b) ((a > b) ? a : b) #define min(a,b) ((a < b) ? a : b) typedef char height; typedef struct strbstnode { char* ind; void* data; struct strbstnode* left; struct strbstnode* right; int ht; } strbstnode; struct strbst { strbstnode* head; }; strbst* newbst(void) { strbst* out = malloc(sizeof(strbst)); out->head = NULL; return out; } static strbstnode* newbstnode(char* ind, void* data) { // does not ins strbstnode* out = malloc(sizeof(strbstnode)); out->ind = ind; out->data = data; out->left = out->right = NULL; out->ht = 0; return out; } int getht(strbstnode* node) { return node ? node->ht : -1; } static void rightrot(strbstnode** bst); static void leftrot(strbstnode** bst); int insnode(strbstnode** bst, strbstnode* new) { // Find good parent if (!*bst){ *bst = new; return 0; } strbstnode** node; if (strcmp(new->ind, (*bst)->ind) < 0) node = &((*bst)->left); else node = &((*bst)->right); int heavy = insnode( node, new ); int diff = getht((*bst)->right) - getht((*bst)->left); if (diff > 1) { if (heavy == -1) rightrot(node); leftrot(bst); } else if (diff < -1){ if (heavy == 1) leftrot(node); rightrot(bst); } (*bst)->ht = max( getht((*bst)->left), getht((*bst)->right) )+1; return ( getht((*bst)->right) - getht((*bst)->left) ); } void insbst(strbst* bst, char* ind, void* data) { insnode(&(bst->head), newbstnode(ind, data)); // allows recursion } void* querynode(strbstnode* node, char* ind){ int cmp = strcmp(ind, node->ind); if (cmp > 0) return querynode(node->right, ind); else if (cmp < 0) return querynode(node->left, ind); else if (node) return node->data; else return NULL; } void* querybst(strbst* bst, char* ind) { return querynode(bst->head, ind); } void rightrot(strbstnode** bst) { /* a b / \ / \ b e => c a / \ / \ c d d e */ strbstnode* flip = (*bst)->left->right; // flip = d (*bst)->left->right = *bst; // position of d = a *bst = (*bst)->left; // tip = b (*bst)->right->left = flip; } void leftrot(strbstnode** bst) { /* a b / \ / \ e b => a c / \ / \ d c e d */ strbstnode* flip = (*bst)->right->left; // flip = d (*bst)->right->left = *bst; // position of d = a *bst = (*bst)->right; // tip = b (*bst)->left->right = flip; } void printnode(strbstnode* node) { if (!node) return; if (node->right) printnode(node->right); if (node->left) printnode(node->left); printf("node %s", node->ind); if (node->right) printf(" (right) %s",node->right->ind); if (node->left) printf(" (left) %s", node->left->ind); printf("\n"); } void printbst(strbst* bst) { printnode(bst->head); }