diff --git a/server/channels/api4/license_test.go b/server/channels/api4/license_test.go index 73af78717bf..884b3fbe7c5 100644 --- a/server/channels/api4/license_test.go +++ b/server/channels/api4/license_test.go @@ -109,7 +109,7 @@ func TestUploadLicenseFile(t *testing.T) { licenseBytes, _ := json.Marshal(license) licenseStr := string(licenseBytes) - mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(true, licenseStr) + mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(licenseStr, nil) utils.LicenseValidator = &mockLicenseValidator licenseManagerMock := &mocks.LicenseInterface{} @@ -144,7 +144,7 @@ func TestUploadLicenseFile(t *testing.T) { licenseBytes, err := json.Marshal(license) require.NoError(t, err) - mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(true, string(licenseBytes)) + mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(string(licenseBytes), nil) utils.LicenseValidator = &mockLicenseValidator resp, err := th.SystemAdminClient.UploadLicenseFile(context.Background(), []byte("")) @@ -177,7 +177,7 @@ func TestUploadLicenseFile(t *testing.T) { licenseBytes, _ := json.Marshal(license) licenseStr := string(licenseBytes) - mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(true, licenseStr) + mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(licenseStr, nil) utils.LicenseValidator = &mockLicenseValidator @@ -275,7 +275,7 @@ func TestRequestTrialLicenseWithExtraFields(t *testing.T) { mockLicenseValidator := mocks2.LicenseValidatorIface{} defer testutils.ResetLicenseValidator() - mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(true, string(licenseJSON)) + mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(string(licenseJSON), nil) utils.LicenseValidator = &mockLicenseValidator licenseManagerMock := &mocks.LicenseInterface{} licenseManagerMock.On("CanStartTrial").Return(true, nil).Once() @@ -306,7 +306,7 @@ func TestRequestTrialLicenseWithExtraFields(t *testing.T) { mockLicenseValidator := mocks2.LicenseValidatorIface{} defer testutils.ResetLicenseValidator() - mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(true, string(licenseJSON)) + mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(string(licenseJSON), nil) utils.LicenseValidator = &mockLicenseValidator licenseManagerMock := &mocks.LicenseInterface{} licenseManagerMock.On("CanStartTrial").Return(true, nil).Once() @@ -344,7 +344,7 @@ func TestRequestTrialLicenseWithExtraFields(t *testing.T) { mockLicenseValidator := mocks2.LicenseValidatorIface{} defer testutils.ResetLicenseValidator() - mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(true, string(licenseJSON)) + mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(string(licenseJSON), nil) utils.LicenseValidator = &mockLicenseValidator licenseManagerMock := &mocks.LicenseInterface{} licenseManagerMock.On("CanStartTrial").Return(true, nil).Once() @@ -405,7 +405,7 @@ func TestRequestTrialLicense(t *testing.T) { mockLicenseValidator := mocks2.LicenseValidatorIface{} defer testutils.ResetLicenseValidator() - mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(true, string(licenseJSON)) + mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(string(licenseJSON), nil) utils.LicenseValidator = &mockLicenseValidator licenseManagerMock := &mocks.LicenseInterface{} licenseManagerMock.On("CanStartTrial").Return(true, nil).Once() @@ -435,7 +435,7 @@ func TestRequestTrialLicense(t *testing.T) { mockLicenseValidator := mocks2.LicenseValidatorIface{} defer testutils.ResetLicenseValidator() - mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(true, string(licenseJSON)) + mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(string(licenseJSON), nil) utils.LicenseValidator = &mockLicenseValidator licenseManagerMock := &mocks.LicenseInterface{} licenseManagerMock.On("CanStartTrial").Return(true, nil).Once() diff --git a/server/channels/app/license.go b/server/channels/app/license.go index 65ede1188a0..8c5a71d2651 100644 --- a/server/channels/app/license.go +++ b/server/channels/app/license.go @@ -129,7 +129,7 @@ func (s *Server) SetLicense(license *model.License) bool { return s.platform.SetLicense(license) } -func (s *Server) ValidateAndSetLicenseBytes(b []byte) bool { +func (s *Server) ValidateAndSetLicenseBytes(b []byte) error { return s.platform.ValidateAndSetLicenseBytes(b) } diff --git a/server/channels/app/platform/license.go b/server/channels/app/platform/license.go index e237b8d90ad..2f068266b13 100644 --- a/server/channels/app/platform/license.go +++ b/server/channels/app/platform/license.go @@ -54,9 +54,9 @@ func (ps *PlatformService) LoadLicense() { // ENV var overrides all other sources of license. licenseStr := os.Getenv(LicenseEnv) if licenseStr != "" { - license, err := utils.LicenseValidator.LicenseFromBytes([]byte(licenseStr)) - if err != nil { - ps.logger.Error("Failed to read license set in environment.", mlog.Err(err)) + license, appErr := utils.LicenseValidator.LicenseFromBytes([]byte(licenseStr)) + if appErr != nil { + ps.logger.Error("Failed to read license set in environment.", mlog.Err(appErr)) return } @@ -74,7 +74,9 @@ func (ps *PlatformService) LoadLicense() { } } - if ps.ValidateAndSetLicenseBytes([]byte(licenseStr)) { + if err := ps.ValidateAndSetLicenseBytes([]byte(licenseStr)); err != nil { + ps.logger.Info("License key from ENV is invalid.", mlog.Err(err)) + } else { ps.logger.Info("License key from ENV is valid, unlocking enterprise features.") } return @@ -88,9 +90,10 @@ func (ps *PlatformService) LoadLicense() { if !model.IsValidId(licenseId) { // Lets attempt to load the file from disk since it was missing from the DB - license, licenseBytes := utils.GetAndValidateLicenseFileFromDisk(*ps.Config().ServiceSettings.LicenseFileLocation) - - if license != nil { + license, licenseBytes, err := utils.GetAndValidateLicenseFileFromDisk(*ps.Config().ServiceSettings.LicenseFileLocation) + if err != nil { + ps.logger.Warn("Failed to get license from disk", mlog.Err(err)) + } else { if _, err := ps.SaveLicense(licenseBytes); err != nil { ps.logger.Error("Failed to save license key loaded from disk.", mlog.Err(err)) } else { @@ -101,19 +104,23 @@ func (ps *PlatformService) LoadLicense() { record, nErr := ps.Store.License().Get(sqlstore.RequestContextWithMaster(c), licenseId) if nErr != nil { - ps.logger.Error("License key from https://mattermost.com required to unlock enterprise features.", mlog.Err(nErr)) + ps.logger.Warn("License key from https://mattermost.com required to unlock enterprise features.", mlog.Err(nErr)) ps.SetLicense(nil) return } - ps.ValidateAndSetLicenseBytes([]byte(record.Bytes)) - ps.logger.Info("License key valid unlocking enterprise features.") + err := ps.ValidateAndSetLicenseBytes([]byte(record.Bytes)) + if err != nil { + ps.logger.Info("License key is invalid.") + } + + ps.logger.Info("License key is valid, unlocking enterprise features.") } func (ps *PlatformService) SaveLicense(licenseBytes []byte) (*model.License, *model.AppError) { - success, licenseStr := utils.LicenseValidator.ValidateLicense(licenseBytes) - if !success { - return nil, model.NewAppError("addLicense", model.InvalidLicenseError, nil, "", http.StatusBadRequest) + licenseStr, err := utils.LicenseValidator.ValidateLicense(licenseBytes) + if err != nil { + return nil, model.NewAppError("addLicense", model.InvalidLicenseError, nil, "", http.StatusBadRequest).Wrap(err) } var license model.License @@ -231,19 +238,19 @@ func (ps *PlatformService) SetLicense(license *model.License) bool { return false } -func (ps *PlatformService) ValidateAndSetLicenseBytes(b []byte) bool { - if success, licenseStr := utils.LicenseValidator.ValidateLicense(b); success { - var license model.License - if jsonErr := json.Unmarshal([]byte(licenseStr), &license); jsonErr != nil { - ps.logger.Warn("Failed to decode license from JSON", mlog.Err(jsonErr)) - return false - } - ps.SetLicense(&license) - return true +func (ps *PlatformService) ValidateAndSetLicenseBytes(b []byte) error { + licenseStr, err := utils.LicenseValidator.ValidateLicense(b) + if err != nil { + return errors.Wrap(err, "Failed to decode license from JSON") } - ps.logger.Warn("No valid enterprise license found") - return false + var license model.License + if err := json.Unmarshal([]byte(licenseStr), &license); err != nil { + return errors.Wrap(err, "Failed to decode license from JSON") + } + + ps.SetLicense(&license) + return nil } func (ps *PlatformService) SetClientLicense(m map[string]string) { diff --git a/server/channels/utils/license.go b/server/channels/utils/license.go index b2e51ab7d0c..fdd13b92d84 100644 --- a/server/channels/utils/license.go +++ b/server/channels/utils/license.go @@ -11,6 +11,7 @@ import ( "encoding/base64" "encoding/json" "encoding/pem" + "fmt" "io" "net/http" "os" @@ -32,33 +33,32 @@ func init() { type LicenseValidatorIface interface { LicenseFromBytes(licenseBytes []byte) (*model.License, *model.AppError) - ValidateLicense(signed []byte) (bool, string) + ValidateLicense(signed []byte) (string, error) } type LicenseValidatorImpl struct { } func (l *LicenseValidatorImpl) LicenseFromBytes(licenseBytes []byte) (*model.License, *model.AppError) { - success, licenseStr := l.ValidateLicense(licenseBytes) - if !success { - return nil, model.NewAppError("LicenseFromBytes", model.InvalidLicenseError, nil, "", http.StatusBadRequest) + licenseStr, err := l.ValidateLicense(licenseBytes) + if err != nil { + return nil, model.NewAppError("LicenseFromBytes", model.InvalidLicenseError, nil, "", http.StatusBadRequest).Wrap(err) } var license model.License - if jsonErr := json.Unmarshal([]byte(licenseStr), &license); jsonErr != nil { - return nil, model.NewAppError("LicenseFromBytes", "api.unmarshal_error", nil, "", http.StatusInternalServerError).Wrap(jsonErr) + if err := json.Unmarshal([]byte(licenseStr), &license); err != nil { + return nil, model.NewAppError("LicenseFromBytes", "api.unmarshal_error", nil, "", http.StatusInternalServerError).Wrap(err) } return &license, nil } -func (l *LicenseValidatorImpl) ValidateLicense(signed []byte) (bool, string) { +func (l *LicenseValidatorImpl) ValidateLicense(signed []byte) (string, error) { decoded := make([]byte, base64.StdEncoding.DecodedLen(len(signed))) _, err := base64.StdEncoding.Decode(decoded, signed) if err != nil { - mlog.Error("Encountered error decoding license", mlog.Err(err)) - return false, "" + return "", fmt.Errorf("encountered error decoding license: %w", err) } // remove null terminator @@ -67,8 +67,7 @@ func (l *LicenseValidatorImpl) ValidateLicense(signed []byte) (bool, string) { } if len(decoded) <= 256 { - mlog.Error("Signed license not long enough") - return false, "" + return "", fmt.Errorf("Signed license not long enough") } plaintext := decoded[:len(decoded)-256] @@ -85,8 +84,7 @@ func (l *LicenseValidatorImpl) ValidateLicense(signed []byte) (bool, string) { public, err := x509.ParsePKIXPublicKey(block.Bytes) if err != nil { - mlog.Error("Encountered error signing license", mlog.Err(err)) - return false, "" + return "", fmt.Errorf("Encountered error signing license: %w", err) } rsaPublic := public.(*rsa.PublicKey) @@ -97,37 +95,34 @@ func (l *LicenseValidatorImpl) ValidateLicense(signed []byte) (bool, string) { err = rsa.VerifyPKCS1v15(rsaPublic, crypto.SHA512, d, signature) if err != nil { - mlog.Error("Invalid signature", mlog.Err(err)) - return false, "" + return "", fmt.Errorf("Invalid signature: %w", err) } - return true, string(plaintext) + return string(plaintext), nil } -func GetAndValidateLicenseFileFromDisk(location string) (*model.License, []byte) { +func GetAndValidateLicenseFileFromDisk(location string) (*model.License, []byte, error) { fileName := GetLicenseFileLocation(location) + mlog.Info("License key has not been uploaded. Loading license key from disk.", mlog.String("filename", fileName)) + if _, err := os.Stat(fileName); err != nil { - mlog.Debug("We could not find the license key in the database or on disk at", mlog.String("filename", fileName)) - return nil, nil + return nil, nil, fmt.Errorf("We could not find the license key on disk at %s: %w", fileName, err) } - mlog.Info("License key has not been uploaded. Loading license key from disk at", mlog.String("filename", fileName)) licenseBytes := GetLicenseFileFromDisk(fileName) - success, licenseStr := LicenseValidator.ValidateLicense(licenseBytes) - if !success { - mlog.Error("Found license key at %v but it appears to be invalid.", mlog.String("filename", fileName)) - return nil, nil + licenseStr, err := LicenseValidator.ValidateLicense(licenseBytes) + if err != nil { + return nil, nil, fmt.Errorf("Found license key at %s but it appears to be invalid: %w", fileName, err) } var license model.License if jsonErr := json.Unmarshal([]byte(licenseStr), &license); jsonErr != nil { - mlog.Error("Failed to decode license from JSON", mlog.Err(jsonErr)) - return nil, nil + return nil, nil, fmt.Errorf("Found license key at %s but it appears to be invalid: %w", fileName, err) } - return &license, licenseBytes + return &license, licenseBytes, nil } func GetLicenseFileFromDisk(fileName string) []byte { diff --git a/server/channels/utils/license_test.go b/server/channels/utils/license_test.go index 96fce2bd2e4..4a88a869375 100644 --- a/server/channels/utils/license_test.go +++ b/server/channels/utils/license_test.go @@ -19,12 +19,12 @@ var validTestLicense = []byte("eyJpZCI6InpvZ3c2NW44Z2lmajVkbHJoYThtYnUxcGl3Iiwia func TestValidateLicense(t *testing.T) { t.Run("should fail with junk data", func(t *testing.T) { b1 := []byte("junk") - ok, _ := LicenseValidator.ValidateLicense(b1) - require.False(t, ok, "should have failed - bad license") + _, err := LicenseValidator.ValidateLicense(b1) + require.Error(t, err, "should have failed - bad license") b2 := []byte("junkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunk") - ok, _ = LicenseValidator.ValidateLicense(b2) - require.False(t, ok, "should have failed - bad license") + _, err = LicenseValidator.ValidateLicense(b2) + require.Error(t, err, "should have failed - bad license") }) t.Run("should not panic on shorter than expected input", func(t *testing.T) { @@ -42,8 +42,8 @@ func TestValidateLicense(t *testing.T) { err = encoder.Close() require.NoError(t, err) - ok, str := LicenseValidator.ValidateLicense(licenseData.Bytes()) - require.False(t, ok) + str, err := LicenseValidator.ValidateLicense(licenseData.Bytes()) + require.Error(t, err) require.Empty(t, str) }) @@ -61,8 +61,8 @@ func TestValidateLicense(t *testing.T) { err = encoder.Close() require.NoError(t, err) - ok, str := LicenseValidator.ValidateLicense(licenseData.Bytes()) - require.False(t, ok) + str, err := LicenseValidator.ValidateLicense(licenseData.Bytes()) + require.Error(t, err) require.Empty(t, str) }) @@ -70,8 +70,8 @@ func TestValidateLicense(t *testing.T) { os.Setenv("MM_SERVICEENVIRONMENT", model.ServiceEnvironmentTest) defer os.Unsetenv("MM_SERVICEENVIRONMENT") - ok, str := LicenseValidator.ValidateLicense(nil) - require.False(t, ok) + str, err := LicenseValidator.ValidateLicense(nil) + require.Error(t, err) require.Empty(t, str) }) @@ -79,8 +79,8 @@ func TestValidateLicense(t *testing.T) { os.Setenv("MM_SERVICEENVIRONMENT", model.ServiceEnvironmentTest) defer os.Unsetenv("MM_SERVICEENVIRONMENT") - ok, str := LicenseValidator.ValidateLicense(validTestLicense) - require.True(t, ok) + str, err := LicenseValidator.ValidateLicense(validTestLicense) + require.NoError(t, err) require.NotEmpty(t, str) }) @@ -88,8 +88,8 @@ func TestValidateLicense(t *testing.T) { os.Setenv("MM_SERVICEENVIRONMENT", model.ServiceEnvironmentProduction) defer os.Unsetenv("MM_SERVICEENVIRONMENT") - ok, str := LicenseValidator.ValidateLicense(validTestLicense) - require.False(t, ok) + str, err := LicenseValidator.ValidateLicense(validTestLicense) + require.Error(t, err) require.Empty(t, str) }) } @@ -117,7 +117,7 @@ func TestGetLicenseFileFromDisk(t *testing.T) { fileBytes := GetLicenseFileFromDisk(f.Name()) require.NotEmpty(t, fileBytes, "should have read the file") - success, _ := LicenseValidator.ValidateLicense(fileBytes) - assert.False(t, success, "should have been an invalid file") + _, err = LicenseValidator.ValidateLicense(fileBytes) + assert.Error(t, err, "should have been an invalid file") }) } diff --git a/server/channels/utils/mocks/LicenseValidatorIface.go b/server/channels/utils/mocks/LicenseValidatorIface.go index 928c76bbc10..fe77edecba0 100644 --- a/server/channels/utils/mocks/LicenseValidatorIface.go +++ b/server/channels/utils/mocks/LicenseValidatorIface.go @@ -43,24 +43,24 @@ func (_m *LicenseValidatorIface) LicenseFromBytes(licenseBytes []byte) (*model.L } // ValidateLicense provides a mock function with given fields: signed -func (_m *LicenseValidatorIface) ValidateLicense(signed []byte) (bool, string) { +func (_m *LicenseValidatorIface) ValidateLicense(signed []byte) (string, error) { ret := _m.Called(signed) - var r0 bool - var r1 string - if rf, ok := ret.Get(0).(func([]byte) (bool, string)); ok { + var r0 string + var r1 error + if rf, ok := ret.Get(0).(func([]byte) (string, error)); ok { return rf(signed) } - if rf, ok := ret.Get(0).(func([]byte) bool); ok { + if rf, ok := ret.Get(0).(func([]byte) string); ok { r0 = rf(signed) } else { - r0 = ret.Get(0).(bool) + r0 = ret.Get(0).(string) } - if rf, ok := ret.Get(1).(func([]byte) string); ok { + if rf, ok := ret.Get(1).(func([]byte) error); ok { r1 = rf(signed) } else { - r1 = ret.Get(1).(string) + r1 = ret.Error(1) } return r0, r1