AVL (Insertion)

Note:

This code was written during a crunch period and isn't perfect. There will be some errant spacing, some files will be using namespace std, etc. But it's all still usable and can be a handy guideline if you're learning Data Structures.

#include <string>
#include <vector>
#include <queue>

struct Node
{
    Node(int val) : value(val){};
    Node *left;
    Node *right;

    int value;
    int height;
};

class AVL
{
public:
    void add(int val);
    Node *getRoot();

private:
    Node *root = nullptr;

    Node *singleLeftRotation(Node *current);
    Node *singleRightRotation(Node *current);
    Node *leftRightRotation(Node *current);
    Node *rightLeftRotation(Node *current);

    int getHeight(Node *current);
    int balanceFactor(Node *current);

    Node *insert(Node *parent, int val);
};

void AVL::add(int value)
{
    root = insert(getRoot(), value);
}

Node *AVL::getRoot()
{
    return root;
}

Node *AVL::singleLeftRotation(Node *current)
{
    Node *newParent = current->right;

    current->right = newParent->left;
    newParent->left = current;

    current->height = std::max(getHeight(current->left), getHeight(current->right)) + 1;
    newParent->height = std::max(getHeight(newParent->left), getHeight(newParent->right)) + 1;

    return newParent;
}

Node *AVL::singleRightRotation(Node *current)
{
    Node *newParent = current->left;

    current->left = newParent->right;
    newParent->right = current;

    current->height = std::max(getHeight(current->left), getHeight(current->right)) + 1;
    newParent->height = std::max(getHeight(newParent->left), getHeight(newParent->right)) + 1;

    return newParent;
}

Node *AVL::leftRightRotation(Node *current)
{
    current->left = singleLeftRotation(current->left);
    return singleRightRotation(current);
}

Node *AVL::rightLeftRotation(Node *current)
{
    current->right = singleRightRotation(current->right);
    return singleLeftRotation(current);
}

int AVL::getHeight(Node *current)
{
    if (current == nullptr)
        return -1;
    return current->height;
}

int AVL::balanceFactor(Node *current)
{
    if (current == nullptr)
        return 0;
    return getHeight(current->left) - getHeight(current->right);
}

Node *AVL::insert(Node *parent, int value)
{
    if (parent == nullptr)
    {
        return new Node(value);
    }
    else if (value < parent->value)
    {
        parent->left = insert(parent->left, value);
        int bf = balanceFactor(parent);

        if (bf == 2)
        {
            if (value < parent->left->value)
            {
                /*
                    2
                1
                */
                std::cout << "single right" << std::endl;
                parent = singleRightRotation(parent);
            }
            else
            {
                /*
                        1
                    2
                */
                std::cout << "left right" << std::endl;
                parent = leftRightRotation(parent);
            }
        }
    }
    else if (value > parent->value)
    {
        parent->right = insert(parent->right, value);
        int bf = balanceFactor(parent);
        if (bf == -2)
        {
            if (value > parent->right->value)
            {
                /*
                1
                    2
                */
                std::cout << "single left" << std::endl;
                parent = singleLeftRotation(parent);
            }
            else
            {
                /*
                1
                2
                */
                std::cout << "right left" << std::endl;
                parent = rightLeftRotation(parent);
            }
        }
    }
    else
    {
        std::cout << "duplicate" << std::endl;
        return parent;
    }

    parent->height = std::max(getHeight(parent->left), getHeight(parent->right)) + 1;
    return parent;
}

void printLevelOrder(Node *node)
{
    std::queue<Node *> q;
    q.push(node);

    while (!q.empty())
    {
        int size = q.size();
        for (int i = 0; i < size; i++)
        {
            Node *cu = q.front();
            std::cout << cu->value << " ";

            q.pop();
            if (cu->left != nullptr)
                q.push(cu->left);
            if (cu->right != nullptr)
                q.push(cu->right);
        }
        std::cout << std::endl;
    }
}

int main(int argc, char *argv[])
{
    AVL avlTree;
    avlTree.add(20);
    avlTree.add(4);
    avlTree.add(26);
    avlTree.add(9);
    avlTree.add(21);
    avlTree.add(2);
    avlTree.add(7);
    avlTree.add(11);

    printLevelOrder(avlTree.getRoot());
}