From 82c97250000d865ddae57638dca36c0fc1cb190c Mon Sep 17 00:00:00 2001 From: Donne Martin Date: Sun, 14 Aug 2016 08:20:06 -0400 Subject: [PATCH] Move bst to a class --- graphs_trees/bst/bst.py | 48 ++++++++++------- graphs_trees/bst/bst_challenge.ipynb | 40 ++++++++------ graphs_trees/bst/bst_solution.ipynb | 80 +++++++++++++++++----------- graphs_trees/bst/test_bst.py | 32 ++++++----- 4 files changed, 119 insertions(+), 81 deletions(-) diff --git a/graphs_trees/bst/bst.py b/graphs_trees/bst/bst.py index 877eeff..87c1a00 100644 --- a/graphs_trees/bst/bst.py +++ b/graphs_trees/bst/bst.py @@ -10,22 +10,34 @@ class Node(object): return str(self.data) -def insert(root, data): - # Constraint: Assume we are working with valid ints - if root is None: - root = Node(data) - return root - if data <= root.data: - if root.left is None: - root.left = insert(root.left, data) - root.left.parent = root - return root.left +class Bst(object): + + def __init__(self, root=None): + self.root = root + + def insert(self, data): + if data is None: + raise Exception('Data cannot be None') + if self.root is None: + self.root = Node(data) + return self.root + return self._insert(self.root, data) + + def _insert(self, node, data): + # Constraint: Assume we are working with valid ints + if node is None: + return Node(data) + if data <= node.data: + if node.left is None: + node.left = self._insert(node.left, data) + node.left.parent = node + return node.left + else: + return self._insert(node.left, data) else: - return insert(root.left, data) - else: - if root.right is None: - root.right = insert(root.right, data) - root.right.parent = root - return root.right - else: - return insert(root.right, data) \ No newline at end of file + if node.right is None: + node.right = self._insert(node.right, data) + node.right.parent = node + return node.right + else: + return self._insert(node.right, data) \ No newline at end of file diff --git a/graphs_trees/bst/bst_challenge.ipynb b/graphs_trees/bst/bst_challenge.ipynb index 0c774f5..479c637 100644 --- a/graphs_trees/bst/bst_challenge.ipynb +++ b/graphs_trees/bst/bst_challenge.ipynb @@ -98,9 +98,11 @@ " pass\n", "\n", "\n", - "def insert(root, data):\n", - " # TODO: Implement me\n", - " pass" + "class Bst(object):\n", + "\n", + " def insert(self, data):\n", + " # TODO: Implement me\n", + " pass" ] }, { @@ -156,22 +158,25 @@ " def __init__(self):\n", " self.results = Results()\n", "\n", - " def test_tree(self):\n", - " node = Node(5)\n", - " assert_equal(insert(node, 2).data, 2)\n", - " assert_equal(insert(node, 8).data, 8)\n", - " assert_equal(insert(node, 1).data, 1)\n", - " assert_equal(insert(node, 3).data, 3)\n", - " in_order_traversal(node, self.results.add_result)\n", + " def test_tree_one(self):\n", + " bst = Bst()\n", + " bst.insert(5)\n", + " bst.insert(2)\n", + " bst.insert(8)\n", + " bst.insert(1)\n", + " bst.insert(3)\n", + " in_order_traversal(bst.root, self.results.add_result)\n", " assert_equal(str(self.results), '[1, 2, 3, 5, 8]')\n", " self.results.clear_results()\n", "\n", - " node = insert(None, 1)\n", - " assert_equal(insert(node, 2).data, 2)\n", - " assert_equal(insert(node, 3).data, 3)\n", - " assert_equal(insert(node, 4).data, 4)\n", - " insert(node, 5)\n", - " in_order_traversal(node, self.results.add_result)\n", + " def test_tree_two(self):\n", + " bst = Bst()\n", + " bst.insert(1)\n", + " bst.insert(2)\n", + " bst.insert(3)\n", + " bst.insert(4)\n", + " bst.insert(5)\n", + " in_order_traversal(bst.root, self.results.add_result)\n", " assert_equal(str(self.results), '[1, 2, 3, 4, 5]')\n", "\n", " print('Success: test_tree')\n", @@ -179,7 +184,8 @@ "\n", "def main():\n", " test = TestTree()\n", - " test.test_tree()\n", + " test.test_tree_one()\n", + " test.test_tree_two()\n", "\n", "\n", "if __name__ == '__main__':\n", diff --git a/graphs_trees/bst/bst_solution.ipynb b/graphs_trees/bst/bst_solution.ipynb index eac2500..f155c19 100644 --- a/graphs_trees/bst/bst_solution.ipynb +++ b/graphs_trees/bst/bst_solution.ipynb @@ -127,25 +127,37 @@ " return str(self.data)\n", "\n", "\n", - "def insert(root, data):\n", - " # Constraint: Assume we are working with valid ints\n", - " if root is None:\n", - " root = Node(data)\n", - " return root\n", - " if data <= root.data:\n", - " if root.left is None:\n", - " root.left = insert(root.left, data)\n", - " root.left.parent = root\n", - " return root.left\n", + "class Bst(object):\n", + "\n", + " def __init__(self, root=None):\n", + " self.root = root\n", + "\n", + " def insert(self, data):\n", + " if data is None:\n", + " raise Exception('Data cannot be None')\n", + " if self.root is None:\n", + " self.root = Node(data)\n", + " return self.root\n", + " return self._insert(self.root, data)\n", + "\n", + " def _insert(self, node, data):\n", + " # Constraint: Assume we are working with valid ints\n", + " if node is None:\n", + " return Node(data)\n", + " if data <= node.data:\n", + " if node.left is None:\n", + " node.left = self._insert(node.left, data)\n", + " node.left.parent = node\n", + " return node.left\n", + " else:\n", + " return self._insert(node.left, data)\n", " else:\n", - " return insert(root.left, data)\n", - " else:\n", - " if root.right is None:\n", - " root.right = insert(root.right, data)\n", - " root.right.parent = root\n", - " return root.right\n", - " else:\n", - " return insert(root.right, data)" + " if node.right is None:\n", + " node.right = self._insert(node.right, data)\n", + " node.right.parent = node\n", + " return node.right\n", + " else:\n", + " return self._insert(node.right, data)" ] }, { @@ -213,22 +225,25 @@ " def __init__(self):\n", " self.results = Results()\n", "\n", - " def test_tree(self):\n", - " node = Node(5)\n", - " assert_equal(insert(node, 2).data, 2)\n", - " assert_equal(insert(node, 8).data, 8)\n", - " assert_equal(insert(node, 1).data, 1)\n", - " assert_equal(insert(node, 3).data, 3)\n", - " in_order_traversal(node, self.results.add_result)\n", + " def test_tree_one(self):\n", + " bst = Bst()\n", + " bst.insert(5)\n", + " bst.insert(2)\n", + " bst.insert(8)\n", + " bst.insert(1)\n", + " bst.insert(3)\n", + " in_order_traversal(bst.root, self.results.add_result)\n", " assert_equal(str(self.results), '[1, 2, 3, 5, 8]')\n", " self.results.clear_results()\n", "\n", - " node = insert(None, 1)\n", - " assert_equal(insert(node, 2).data, 2)\n", - " assert_equal(insert(node, 3).data, 3)\n", - " assert_equal(insert(node, 4).data, 4)\n", - " insert(node, 5)\n", - " in_order_traversal(node, self.results.add_result)\n", + " def test_tree_two(self):\n", + " bst = Bst()\n", + " bst.insert(1)\n", + " bst.insert(2)\n", + " bst.insert(3)\n", + " bst.insert(4)\n", + " bst.insert(5)\n", + " in_order_traversal(bst.root, self.results.add_result)\n", " assert_equal(str(self.results), '[1, 2, 3, 4, 5]')\n", "\n", " print('Success: test_tree')\n", @@ -236,7 +251,8 @@ "\n", "def main():\n", " test = TestTree()\n", - " test.test_tree()\n", + " test.test_tree_one()\n", + " test.test_tree_two()\n", "\n", "\n", "if __name__ == '__main__':\n", diff --git a/graphs_trees/bst/test_bst.py b/graphs_trees/bst/test_bst.py index a900ea4..4c775ab 100644 --- a/graphs_trees/bst/test_bst.py +++ b/graphs_trees/bst/test_bst.py @@ -6,22 +6,25 @@ class TestTree(object): def __init__(self): self.results = Results() - def test_tree(self): - node = Node(5) - assert_equal(insert(node, 2).data, 2) - assert_equal(insert(node, 8).data, 8) - assert_equal(insert(node, 1).data, 1) - assert_equal(insert(node, 3).data, 3) - in_order_traversal(node, self.results.add_result) + def test_tree_one(self): + bst = Bst() + bst.insert(5) + bst.insert(2) + bst.insert(8) + bst.insert(1) + bst.insert(3) + in_order_traversal(bst.root, self.results.add_result) assert_equal(str(self.results), '[1, 2, 3, 5, 8]') self.results.clear_results() - node = insert(None, 1) - assert_equal(insert(node, 2).data, 2) - assert_equal(insert(node, 3).data, 3) - assert_equal(insert(node, 4).data, 4) - insert(node, 5) - in_order_traversal(node, self.results.add_result) + def test_tree_two(self): + bst = Bst() + bst.insert(1) + bst.insert(2) + bst.insert(3) + bst.insert(4) + bst.insert(5) + in_order_traversal(bst.root, self.results.add_result) assert_equal(str(self.results), '[1, 2, 3, 4, 5]') print('Success: test_tree') @@ -29,7 +32,8 @@ class TestTree(object): def main(): test = TestTree() - test.test_tree() + test.test_tree_one() + test.test_tree_two() if __name__ == '__main__':