diff --git a/supplier_app/tests/test_view.py b/supplier_app/tests/test_view.py index 9a4ac1b..5bda867 100644 --- a/supplier_app/tests/test_view.py +++ b/supplier_app/tests/test_view.py @@ -1707,6 +1707,8 @@ def setUp(self): 'eb_entity':'1', 'description': 'Bringing the world together through live experiences', + 'email': 'buyer@eventbrite.com', + 'language': 'en', } self.client = Client() self.ap_user = User.objects.create_user(email='ap@eventbrite.com') diff --git a/supplier_app/views.py b/supplier_app/views.py index f308921..d4a1356 100644 --- a/supplier_app/views.py +++ b/supplier_app/views.py @@ -95,11 +95,11 @@ from utils.send_email import company_invitation_notification from utils.htmltopdf import render_to_pdf + class CompanyCreatorView(UserLoginPermissionRequiredMixin, CreateView): model = Company fields = '__all__' template_name = 'supplier_app/AP/company_creation.html' - success_url = reverse_lazy('company-list') permission_required = ( CAN_CREATE_COMPANY_PERM, ) @@ -115,6 +115,7 @@ def form_valid(self, form): company = self.save_company(form) InvitingBuyer.objects.create(company=company, inviting_buyer=self.request.user) EBEntityCompany.objects.create(company=company, eb_entity=EBEntity.objects.get(pk=form.data['eb_entity'])) + company_invite(self.request, company) return HttpResponseRedirect(self.get_success_url()) def save_company(self, forms): @@ -132,6 +133,7 @@ def get_failure_url(self): ) return reverse('company-create') + class CompanyListView(LoginRequiredMixin, ListView): model = Company template_name = 'supplier_app/AP/company_list.html' @@ -550,14 +552,15 @@ def get_success_url(self): @transaction.atomic -def company_invite(request): +def company_invite(request, company=None): try: old_language = translation.get_language() language = request.POST['language'] translation.activate(language) email = [request.POST['email']] - company_id = request.POST['company_id'] - company = Company.objects.get(pk=company_id) + if not company: + company_id = request.POST['company_id'] + company = Company.objects.get(pk=company_id) company_unique_token = CompanyUniqueToken(company=company) company_unique_token.assing_company_token company_unique_token.save() diff --git a/templates/supplier_app/AP/company_creation.html b/templates/supplier_app/AP/company_creation.html index f917a7d..eed4a5d 100644 --- a/templates/supplier_app/AP/company_creation.html +++ b/templates/supplier_app/AP/company_creation.html @@ -32,6 +32,20 @@