import os import socket import time import unittest import MySQLdb from django.conf import settings as djsettings from django.core.management.base import CommandError from django.urls import reverse from orchestra.admin.utils import change_url from orchestra.contrib.orchestration.models import Route, Server from orchestra.utils.sys import sshrun from orchestra.utils.tests import (BaseLiveServerTestCase, random_ascii, save_response_on_error, snapshot_on_error) from selenium.webdriver.support.select import Select from ... import backends, settings from ...models import Database, DatabaseUser TEST_REST_API = int(os.getenv('TEST_REST_API', '0')) class DatabaseTestMixin(object): MASTER_SERVER = os.environ.get('ORCHESTRA_SECOND_SERVER', 'localhost') DEPENDENCIES = ( 'orchestra.contrib.orchestration', 'orcgestra.apps.databases', ) def setUp(self): super(DatabaseTestMixin, self).setUp() self.add_route() djsettings.DEBUG = True def add_route(self): raise NotImplementedError def save(self): raise NotImplementedError def add(self): raise NotImplementedError def delete(self): raise NotImplementedError def update(self): raise NotImplementedError def disable(self): raise NotImplementedError def add_group(self, username, groupname): raise NotImplementedError def test_add(self): dbname = '%s_database' % random_ascii(5) username = '%s_dbuser' % random_ascii(5) password = '@!?%spppP001' % random_ascii(5) self.add(dbname, username, password) self.validate_create_table(dbname, username, password) def test_delete(self): dbname = '%s_database' % random_ascii(5) username = '%s_dbuser' % random_ascii(5) password = '@!?%spppP001' % random_ascii(5) self.add(dbname, username, password) self.validate_create_table(dbname, username, password) self.delete(dbname) self.delete_user(username) self.validate_delete(dbname, username, password) self.validate_delete_user(dbname, username) def test_change_password(self): dbname = '%s_database' % random_ascii(5) username = '%s_dbuser' % random_ascii(5) password = '@!?%spppP001' % random_ascii(5) self.add(dbname, username, password) self.addCleanup(self.delete, dbname) self.addCleanup(self.delete_user, username) self.validate_create_table(dbname, username, password) new_password = '@!?%spppP001' % random_ascii(5) self.change_password(username, new_password) self.validate_login_error(dbname, username, password) self.validate_create_table(dbname, username, new_password) def test_add_user(self): dbname = '%s_database' % random_ascii(5) username = '%s_dbuser' % random_ascii(5) password = '@!?%spppP001' % random_ascii(5) self.add(dbname, username, password) self.addCleanup(self.delete, dbname) self.addCleanup(self.delete_user, username) self.validate_create_table(dbname, username, password) username2 = '%s_dbuser' % random_ascii(5) password2 = '@!?%spppP001' % random_ascii(5) self.add_user(username2, password2) self.addCleanup(self.delete_user, username2) self.validate_login_error(dbname, username2, password2) self.add_user_to_db(username2, dbname) self.validate_create_table(dbname, username, password) self.validate_create_table(dbname, username2, password2) def test_delete_user(self): dbname = '%s_database' % random_ascii(5) username = '%s_dbuser' % random_ascii(5) password = '@!?%spppP001' % random_ascii(5) self.add(dbname, username, password) self.addCleanup(self.delete, dbname) self.validate_create_table(dbname, username, password) username2 = '%s_dbuser' % random_ascii(5) password2 = '@!?%spppP001' % random_ascii(5) self.add_user(username2, password2) self.add_user_to_db(username2, dbname) self.delete_user(username) self.validate_delete_user(username, password) self.validate_login_error(dbname, username, password) self.validate_create_table(dbname, username2, password2) self.delete_user(username2) self.validate_login_error(dbname, username2, password2) self.validate_delete_user(username2, password2) def test_swap_user(self): dbname = '%s_database' % random_ascii(5) username = '%s_dbuser' % random_ascii(5) password = '@!?%spppP001' % random_ascii(5) self.add(dbname, username, password) self.addCleanup(self.delete, dbname) self.addCleanup(self.delete_user, username) self.validate_create_table(dbname, username, password) username2 = '%s_dbuser' % random_ascii(5) password2 = '@!?%spppP001' % random_ascii(5) self.add_user(username2, password2) self.addCleanup(self.delete_user, username2) self.swap_user(username, username2, dbname) self.validate_login_error(dbname, username, password) self.validate_create_table(dbname, username2, password2) class MySQLControllerMixin(object): db_type = 'mysql' def setUp(self): super(MySQLControllerMixin, self).setUp() # Get local ip address used to reach self.MASTER_SERVER s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s.connect((self.MASTER_SERVER, 22)) settings.DATABASES_DEFAULT_HOST = s.getsockname()[0] s.close() def add_route(self): server = Server.objects.create(name=self.MASTER_SERVER) backend = backends.MySQLController.get_name() match = "database.type == '%s'" % self.db_type Route.objects.create(backend=backend, match=match, host=server) match = "databaseuser.type == '%s'" % self.db_type backend = backends.MySQLUserController.get_name() Route.objects.create(backend=backend, match=match, host=server) def validate_create_table(self, name, username, password): db = MySQLdb.connect(host=self.MASTER_SERVER, port=3306, user=username, passwd=password, db=name) cur = db.cursor() cur.execute('CREATE TABLE table_%s ( id INT ) ;' % random_ascii(10)) def validate_login_error(self, dbname, username, password): self.assertRaises(MySQLdb.OperationalError, self.validate_create_table, dbname, username, password ) def validate_delete(self, dbname, username, password): self.validate_login_error(dbname, username, password) self.assertRaises(CommandError, sshrun, self.MASTER_SERVER, 'mysql %s' % dbname, display=False) def validate_delete_user(self, name, username): context = { 'name': name, 'username': username, } self.assertEqual('', sshrun(self.MASTER_SERVER, """mysql mysql -e 'SELECT * FROM db WHERE db="%(name)s";'""" % context, display=False).stdout) self.assertEqual('', sshrun(self.MASTER_SERVER, """mysql mysql -e 'SELECT * FROM user WHERE user="%(username)s";'""" % context, display=False).stdout) @unittest.skipUnless(TEST_REST_API, "REST API tests") class RESTDatabaseMixin(DatabaseTestMixin): def setUp(self): super(RESTDatabaseMixin, self).setUp() self.rest_login() @save_response_on_error def add(self, dbname, username, password): user = self.rest.databaseusers.create(username=username, password=password, type=self.db_type) users = [{ 'username': user.username }] self.rest.databases.create(name=dbname, users=users, type=self.db_type) @save_response_on_error def delete(self, dbname): self.rest.databases.retrieve(name=dbname).delete() @save_response_on_error def change_password(self, username, password): user = self.rest.databaseusers.retrieve(username=username).get() user.set_password(password) @save_response_on_error def add_user(self, username, password): self.rest.databaseusers.create(username=username, password=password, type=self.db_type) @save_response_on_error def add_user_to_db(self, username, dbname): user = self.rest.databaseusers.retrieve(username=username).get() db = self.rest.databases.retrieve(name=dbname).get() db.users.append(user) db.save() @save_response_on_error def delete_user(self, username): self.rest.databaseusers.retrieve(username=username).delete() @save_response_on_error def swap_user(self, username, username2, dbname): user = self.rest.databaseusers.retrieve(username=username2).get() db = self.rest.databases.retrieve(name=dbname).get() db.users = db.users.exclude(username=username) db.users.append(user) db.save() class AdminDatabaseMixin(DatabaseTestMixin): def setUp(self): super(AdminDatabaseMixin, self).setUp() self.admin_login() @snapshot_on_error def add(self, dbname, username, password): url = self.live_server_url + reverse('admin:databases_database_add') self.selenium.get(url) type_input = self.selenium.find_element_by_id('id_type') type_select = Select(type_input) type_select.select_by_value(self.db_type) name_field = self.selenium.find_element_by_id('id_name') name_field.send_keys(dbname) username_field = self.selenium.find_element_by_id('id_username') username_field.send_keys(username) password_field = self.selenium.find_element_by_id('id_password1') password_field.send_keys(password) password_field = self.selenium.find_element_by_id('id_password2') password_field.send_keys(password) name_field.submit() self.assertNotEqual(url, self.selenium.current_url) @snapshot_on_error def delete(self, dbname): db = Database.objects.get(name=dbname) self.admin_delete(db) @snapshot_on_error def change_password(self, username, password): user = DatabaseUser.objects.get(username=username) self.admin_change_password(user, password) @snapshot_on_error def add_user(self, username, password): url = self.live_server_url + reverse('admin:databases_databaseuser_add') self.selenium.get(url) type_input = self.selenium.find_element_by_id('id_type') type_select = Select(type_input) type_select.select_by_value(self.db_type) username_field = self.selenium.find_element_by_id('id_username') username_field.send_keys(username) password_field = self.selenium.find_element_by_id('id_password1') password_field.send_keys(password) password_field = self.selenium.find_element_by_id('id_password2') password_field.send_keys(password) username_field.submit() self.assertNotEqual(url, self.selenium.current_url) @snapshot_on_error def add_user_to_db(self, username, dbname): database = Database.objects.get(name=dbname, type=self.db_type) url = self.live_server_url + change_url(database) self.selenium.get(url) user = DatabaseUser.objects.get(username=username, type=self.db_type) users_from = self.selenium.find_element_by_id('id_users_from') users_select = Select(users_from) users_select.select_by_value(str(user.pk)) add_user = self.selenium.find_element_by_id('id_users_add_link') add_user.click() save = self.selenium.find_element_by_name('_save') save.submit() self.assertNotEqual(url, self.selenium.current_url) @snapshot_on_error def swap_user(self, username, username2, dbname): database = Database.objects.get(name=dbname, type=self.db_type) url = self.live_server_url + change_url(database) self.selenium.get(url) # remove user "username" user = DatabaseUser.objects.get(username=username, type=self.db_type) users_to = self.selenium.find_element_by_id('id_users_to') users_select = Select(users_to) users_select.select_by_value(str(user.pk)) remove_user = self.selenium.find_element_by_id('id_users_remove_link') remove_user.click() time.sleep(0.2) # add user "username2" user = DatabaseUser.objects.get(username=username2, type=self.db_type) users_from = self.selenium.find_element_by_id('id_users_from') users_select = Select(users_from) users_select.select_by_value(str(user.pk)) add_user = self.selenium.find_element_by_id('id_users_add_link') add_user.click() time.sleep(0.2) save = self.selenium.find_element_by_name('_save') save.submit() self.assertNotEqual(url, self.selenium.current_url) @snapshot_on_error def delete_user(self, username): user = DatabaseUser.objects.get(username=username) self.admin_delete(user) class RESTMysqlDatabaseTest(MySQLControllerMixin, RESTDatabaseMixin, BaseLiveServerTestCase): pass class AdminMysqlDatabaseTest(MySQLControllerMixin, AdminDatabaseMixin, BaseLiveServerTestCase): pass