aboutsummaryrefslogtreecommitdiff
path: root/strbst.c
blob: 7bb53b16859c87fd54bfa285fd1d789492311171 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#include <stdlib.h>
#include <string.h>
#include <stdio.h>

#include "strbst.h"
/* Min/max helpers. Every argument and the whole expansion are
 * parenthesized so operands containing lower-precedence operators
 * (e.g. `x | y`) expand correctly. Arguments may be evaluated twice,
 * so do not pass expressions with side effects. */
#define max(a,b) (((a) > (b)) ? (a) : (b))
#define min(a,b) (((a) < (b)) ? (a) : (b))

/* NOTE(review): not referenced elsewhere in this file; kept for compatibility. */
typedef char height;

/* Allocate an empty tree (head == NULL).
 * Returns NULL if the allocation fails instead of dereferencing it. */
strbst* newbst(void) {
    strbst* out = malloc(sizeof *out);
    if (!out)
        return NULL;
    out->head = NULL;
    return out;
}

/* Allocate a detached leaf node (no children, height 0) holding `ind`
 * and `data`. It is NOT inserted into any tree.
 * The key is stored by pointer, not copied, so the string must outlive
 * the node. Returns NULL if the allocation fails. */
static strbstnode* newbstnode(char* ind, void* data) {
    strbstnode* out = malloc(sizeof *out);
    if (!out)
        return NULL;
    out->ind = ind;
    out->data = data;
    out->left = out->right = NULL;
    out->ht = 0;
    return out;
}

/* Cached height of a subtree: -1 for an empty (NULL) subtree, 0 for a leaf. */
int getht(strbstnode* node) {
    if (node == NULL)
        return -1;
    return node->ht;
}

/* Forward declarations: insnode uses the rotations before their definitions. */
static void rightrot(strbstnode **root);
static void leftrot(strbstnode **root);

/* Recompute `node`'s cached height from its children, whose heights must
 * already be correct. NULL-safe, so callers need not check which children exist. */
static void refreshht(strbstnode* node) {
    if (node)
        node->ht = max(getht(node->left), getht(node->right)) + 1;
}

/* Insert `new` into the AVL subtree rooted at *bst and rebalance on the way
 * back up the recursion. Keys comparing equal to an existing key go to the
 * right subtree.
 *
 * Returns the balance factor (right height minus left height) of *bst after
 * insertion and any rebalancing. The recursive caller uses its child's value
 * to choose between a single and a double rotation. */
int insnode(strbstnode** bst, strbstnode* new) {
    if (!*bst) {
        *bst = new;
        return 0;
    }

    strbstnode** node = (strcmp(new->ind, (*bst)->ind) < 0)
        ? &(*bst)->left
        : &(*bst)->right;
    int heavy = insnode(node, new);
    int diff = getht((*bst)->right) - getht((*bst)->left);
    int rotated = 0;

    if (diff > 1) {
        /* right-heavy; a left-heavy right child needs a double rotation */
        if (heavy == -1)
            rightrot(node);
        leftrot(bst);
        rotated = 1;
    } else if (diff < -1) {
        /* left-heavy; a right-heavy left child needs a double rotation */
        if (heavy == 1)
            leftrot(node);
        rightrot(bst);
        rotated = 1;
    }

    /* The rotations only relink pointers. Nodes they demote keep stale
     * heights, so refresh both children of the new root before the root. */
    if (rotated) {
        refreshht((*bst)->left);
        refreshht((*bst)->right);
    }
    refreshht(*bst);
    return getht((*bst)->right) - getht((*bst)->left);
}

/* Insert `data` under key `ind` into the tree. The key pointer is stored,
 * not copied. If a node cannot be created (NULL), the tree is left unchanged
 * rather than handing a NULL node to the insertion routine. */
void insbst(strbst* bst, char* ind, void* data) {
    strbstnode* node = newbstnode(ind, data);
    if (!node)
        return;
    insnode(&bst->head, node); /* pass &head so the root itself can be replaced */
}

/* Look up `ind` in the subtree rooted at `node`.
 * Returns the stored data pointer of the first node found whose key compares
 * equal (strcmp == 0), or NULL if the key is absent or the subtree is empty.
 * The NULL check precedes every dereference, so misses and empty trees are safe. */
void* querynode(strbstnode* node, char* ind) {
    while (node) {
        int cmp = strcmp(ind, node->ind);
        if (cmp == 0)
            return node->data;
        node = (cmp < 0) ? node->left : node->right;
    }
    return NULL;
}

/* Look up `ind` starting at the tree's root; see querynode for the result. */
void* querybst(strbst* bst, char* ind) {
    strbstnode* root = bst->head;
    return querynode(root, ind);
}

/* Right rotation of the subtree at *bst; (*bst)->left must be non-NULL.
 *
 *     a         b
 *    / \       / \
 *   b   e =>  c   a
 *  / \           / \
 * c   d         d   e
 *
 * Only child pointers and *bst change. Cached ht fields are NOT updated
 * here; the caller must recompute them. */
void rightrot(strbstnode** bst) {
    strbstnode* oldroot = *bst;          /* a */
    strbstnode* pivot = oldroot->left;   /* b */
    oldroot->left = pivot->right;        /* a adopts d */
    pivot->right = oldroot;              /* b adopts a */
    *bst = pivot;                        /* b is the new subtree root */
}

/* Left rotation of the subtree at *bst; (*bst)->right must be non-NULL.
 *
 *   a            b
 *  / \          / \
 * e   b   =>   a   c
 *    / \      / \
 *   d   c    e   d
 *
 * Only child pointers and *bst change. Cached ht fields are NOT updated
 * here; the caller must recompute them. */
void leftrot(strbstnode** bst) {
    strbstnode* oldroot = *bst;          /* a */
    strbstnode* pivot = oldroot->right;  /* b */
    oldroot->right = pivot->left;        /* a adopts d */
    pivot->left = oldroot;               /* b adopts a */
    *bst = pivot;                        /* b is the new subtree root */
}

/* Debug dump of a subtree: visit the right subtree, then the left subtree,
 * then print one line for this node naming whichever children it has. */
void printnode(strbstnode* node) {
    if (node == NULL)
        return;
    printnode(node->right); /* NULL children are handled by the guard above */
    printnode(node->left);
    printf("node %s", node->ind);
    if (node->right != NULL)
        printf(" (right) %s", node->right->ind);
    if (node->left != NULL)
        printf(" (left) %s", node->left->ind);
    putchar('\n');
}

/* Debug dump of the whole tree, starting from its root. */
void printbst(strbst* bst) {
    strbstnode* root = bst->head;
    printnode(root);
}