Skip to content

Commit

Permalink
[MM-56653] Improve license loading errors (mattermost#26050)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanzei authored Apr 5, 2024
1 parent 1a9355b commit 71e26b8
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 84 deletions.
16 changes: 8 additions & 8 deletions server/channels/api4/license_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func TestUploadLicenseFile(t *testing.T) {
licenseBytes, _ := json.Marshal(license)
licenseStr := string(licenseBytes)

mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(true, licenseStr)
mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(licenseStr, nil)
utils.LicenseValidator = &mockLicenseValidator

licenseManagerMock := &mocks.LicenseInterface{}
Expand Down Expand Up @@ -144,7 +144,7 @@ func TestUploadLicenseFile(t *testing.T) {
licenseBytes, err := json.Marshal(license)
require.NoError(t, err)

mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(true, string(licenseBytes))
mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(string(licenseBytes), nil)
utils.LicenseValidator = &mockLicenseValidator

resp, err := th.SystemAdminClient.UploadLicenseFile(context.Background(), []byte(""))
Expand Down Expand Up @@ -177,7 +177,7 @@ func TestUploadLicenseFile(t *testing.T) {
licenseBytes, _ := json.Marshal(license)
licenseStr := string(licenseBytes)

mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(true, licenseStr)
mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(licenseStr, nil)

utils.LicenseValidator = &mockLicenseValidator

Expand Down Expand Up @@ -275,7 +275,7 @@ func TestRequestTrialLicenseWithExtraFields(t *testing.T) {
mockLicenseValidator := mocks2.LicenseValidatorIface{}
defer testutils.ResetLicenseValidator()

mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(true, string(licenseJSON))
mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(string(licenseJSON), nil)
utils.LicenseValidator = &mockLicenseValidator
licenseManagerMock := &mocks.LicenseInterface{}
licenseManagerMock.On("CanStartTrial").Return(true, nil).Once()
Expand Down Expand Up @@ -306,7 +306,7 @@ func TestRequestTrialLicenseWithExtraFields(t *testing.T) {
mockLicenseValidator := mocks2.LicenseValidatorIface{}
defer testutils.ResetLicenseValidator()

mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(true, string(licenseJSON))
mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(string(licenseJSON), nil)
utils.LicenseValidator = &mockLicenseValidator
licenseManagerMock := &mocks.LicenseInterface{}
licenseManagerMock.On("CanStartTrial").Return(true, nil).Once()
Expand Down Expand Up @@ -344,7 +344,7 @@ func TestRequestTrialLicenseWithExtraFields(t *testing.T) {
mockLicenseValidator := mocks2.LicenseValidatorIface{}
defer testutils.ResetLicenseValidator()

mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(true, string(licenseJSON))
mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(string(licenseJSON), nil)
utils.LicenseValidator = &mockLicenseValidator
licenseManagerMock := &mocks.LicenseInterface{}
licenseManagerMock.On("CanStartTrial").Return(true, nil).Once()
Expand Down Expand Up @@ -405,7 +405,7 @@ func TestRequestTrialLicense(t *testing.T) {
mockLicenseValidator := mocks2.LicenseValidatorIface{}
defer testutils.ResetLicenseValidator()

mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(true, string(licenseJSON))
mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(string(licenseJSON), nil)
utils.LicenseValidator = &mockLicenseValidator
licenseManagerMock := &mocks.LicenseInterface{}
licenseManagerMock.On("CanStartTrial").Return(true, nil).Once()
Expand Down Expand Up @@ -435,7 +435,7 @@ func TestRequestTrialLicense(t *testing.T) {
mockLicenseValidator := mocks2.LicenseValidatorIface{}
defer testutils.ResetLicenseValidator()

mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(true, string(licenseJSON))
mockLicenseValidator.On("ValidateLicense", mock.Anything).Return(string(licenseJSON), nil)
utils.LicenseValidator = &mockLicenseValidator
licenseManagerMock := &mocks.LicenseInterface{}
licenseManagerMock.On("CanStartTrial").Return(true, nil).Once()
Expand Down
2 changes: 1 addition & 1 deletion server/channels/app/license.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ func (s *Server) SetLicense(license *model.License) bool {
return s.platform.SetLicense(license)
}

func (s *Server) ValidateAndSetLicenseBytes(b []byte) bool {
func (s *Server) ValidateAndSetLicenseBytes(b []byte) error {
return s.platform.ValidateAndSetLicenseBytes(b)
}

Expand Down
55 changes: 31 additions & 24 deletions server/channels/app/platform/license.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ func (ps *PlatformService) LoadLicense() {
// ENV var overrides all other sources of license.
licenseStr := os.Getenv(LicenseEnv)
if licenseStr != "" {
license, err := utils.LicenseValidator.LicenseFromBytes([]byte(licenseStr))
if err != nil {
ps.logger.Error("Failed to read license set in environment.", mlog.Err(err))
license, appErr := utils.LicenseValidator.LicenseFromBytes([]byte(licenseStr))
if appErr != nil {
ps.logger.Error("Failed to read license set in environment.", mlog.Err(appErr))
return
}

Expand All @@ -74,7 +74,9 @@ func (ps *PlatformService) LoadLicense() {
}
}

if ps.ValidateAndSetLicenseBytes([]byte(licenseStr)) {
if err := ps.ValidateAndSetLicenseBytes([]byte(licenseStr)); err != nil {
ps.logger.Info("License key from ENV is invalid.", mlog.Err(err))
} else {
ps.logger.Info("License key from ENV is valid, unlocking enterprise features.")
}
return
Expand All @@ -88,9 +90,10 @@ func (ps *PlatformService) LoadLicense() {

if !model.IsValidId(licenseId) {
// Lets attempt to load the file from disk since it was missing from the DB
license, licenseBytes := utils.GetAndValidateLicenseFileFromDisk(*ps.Config().ServiceSettings.LicenseFileLocation)

if license != nil {
license, licenseBytes, err := utils.GetAndValidateLicenseFileFromDisk(*ps.Config().ServiceSettings.LicenseFileLocation)
if err != nil {
ps.logger.Warn("Failed to get license from disk", mlog.Err(err))
} else {
if _, err := ps.SaveLicense(licenseBytes); err != nil {
ps.logger.Error("Failed to save license key loaded from disk.", mlog.Err(err))
} else {
Expand All @@ -101,19 +104,23 @@ func (ps *PlatformService) LoadLicense() {

record, nErr := ps.Store.License().Get(sqlstore.RequestContextWithMaster(c), licenseId)
if nErr != nil {
ps.logger.Error("License key from https://mattermost.com required to unlock enterprise features.", mlog.Err(nErr))
ps.logger.Warn("License key from https://mattermost.com required to unlock enterprise features.", mlog.Err(nErr))
ps.SetLicense(nil)
return
}

ps.ValidateAndSetLicenseBytes([]byte(record.Bytes))
ps.logger.Info("License key valid unlocking enterprise features.")
err := ps.ValidateAndSetLicenseBytes([]byte(record.Bytes))
if err != nil {
ps.logger.Info("License key is invalid.")
}

ps.logger.Info("License key is valid, unlocking enterprise features.")
}

func (ps *PlatformService) SaveLicense(licenseBytes []byte) (*model.License, *model.AppError) {
success, licenseStr := utils.LicenseValidator.ValidateLicense(licenseBytes)
if !success {
return nil, model.NewAppError("addLicense", model.InvalidLicenseError, nil, "", http.StatusBadRequest)
licenseStr, err := utils.LicenseValidator.ValidateLicense(licenseBytes)
if err != nil {
return nil, model.NewAppError("addLicense", model.InvalidLicenseError, nil, "", http.StatusBadRequest).Wrap(err)
}

var license model.License
Expand Down Expand Up @@ -231,19 +238,19 @@ func (ps *PlatformService) SetLicense(license *model.License) bool {
return false
}

func (ps *PlatformService) ValidateAndSetLicenseBytes(b []byte) bool {
if success, licenseStr := utils.LicenseValidator.ValidateLicense(b); success {
var license model.License
if jsonErr := json.Unmarshal([]byte(licenseStr), &license); jsonErr != nil {
ps.logger.Warn("Failed to decode license from JSON", mlog.Err(jsonErr))
return false
}
ps.SetLicense(&license)
return true
func (ps *PlatformService) ValidateAndSetLicenseBytes(b []byte) error {
licenseStr, err := utils.LicenseValidator.ValidateLicense(b)
if err != nil {
return errors.Wrap(err, "Failed to decode license from JSON")
}

ps.logger.Warn("No valid enterprise license found")
return false
var license model.License
if err := json.Unmarshal([]byte(licenseStr), &license); err != nil {
return errors.Wrap(err, "Failed to decode license from JSON")
}

ps.SetLicense(&license)
return nil
}

func (ps *PlatformService) SetClientLicense(m map[string]string) {
Expand Down
49 changes: 22 additions & 27 deletions server/channels/utils/license.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"net/http"
"os"
Expand All @@ -32,33 +33,32 @@ func init() {

type LicenseValidatorIface interface {
LicenseFromBytes(licenseBytes []byte) (*model.License, *model.AppError)
ValidateLicense(signed []byte) (bool, string)
ValidateLicense(signed []byte) (string, error)
}

type LicenseValidatorImpl struct {
}

func (l *LicenseValidatorImpl) LicenseFromBytes(licenseBytes []byte) (*model.License, *model.AppError) {
success, licenseStr := l.ValidateLicense(licenseBytes)
if !success {
return nil, model.NewAppError("LicenseFromBytes", model.InvalidLicenseError, nil, "", http.StatusBadRequest)
licenseStr, err := l.ValidateLicense(licenseBytes)
if err != nil {
return nil, model.NewAppError("LicenseFromBytes", model.InvalidLicenseError, nil, "", http.StatusBadRequest).Wrap(err)
}

var license model.License
if jsonErr := json.Unmarshal([]byte(licenseStr), &license); jsonErr != nil {
return nil, model.NewAppError("LicenseFromBytes", "api.unmarshal_error", nil, "", http.StatusInternalServerError).Wrap(jsonErr)
if err := json.Unmarshal([]byte(licenseStr), &license); err != nil {
return nil, model.NewAppError("LicenseFromBytes", "api.unmarshal_error", nil, "", http.StatusInternalServerError).Wrap(err)
}

return &license, nil
}

func (l *LicenseValidatorImpl) ValidateLicense(signed []byte) (bool, string) {
func (l *LicenseValidatorImpl) ValidateLicense(signed []byte) (string, error) {
decoded := make([]byte, base64.StdEncoding.DecodedLen(len(signed)))

_, err := base64.StdEncoding.Decode(decoded, signed)
if err != nil {
mlog.Error("Encountered error decoding license", mlog.Err(err))
return false, ""
return "", fmt.Errorf("encountered error decoding license: %w", err)
}

// remove null terminator
Expand All @@ -67,8 +67,7 @@ func (l *LicenseValidatorImpl) ValidateLicense(signed []byte) (bool, string) {
}

if len(decoded) <= 256 {
mlog.Error("Signed license not long enough")
return false, ""
return "", fmt.Errorf("Signed license not long enough")
}

plaintext := decoded[:len(decoded)-256]
Expand All @@ -85,8 +84,7 @@ func (l *LicenseValidatorImpl) ValidateLicense(signed []byte) (bool, string) {

public, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
mlog.Error("Encountered error signing license", mlog.Err(err))
return false, ""
return "", fmt.Errorf("Encountered error signing license: %w", err)
}

rsaPublic := public.(*rsa.PublicKey)
Expand All @@ -97,37 +95,34 @@ func (l *LicenseValidatorImpl) ValidateLicense(signed []byte) (bool, string) {

err = rsa.VerifyPKCS1v15(rsaPublic, crypto.SHA512, d, signature)
if err != nil {
mlog.Error("Invalid signature", mlog.Err(err))
return false, ""
return "", fmt.Errorf("Invalid signature: %w", err)
}

return true, string(plaintext)
return string(plaintext), nil
}

func GetAndValidateLicenseFileFromDisk(location string) (*model.License, []byte) {
func GetAndValidateLicenseFileFromDisk(location string) (*model.License, []byte, error) {
fileName := GetLicenseFileLocation(location)

mlog.Info("License key has not been uploaded. Loading license key from disk.", mlog.String("filename", fileName))

if _, err := os.Stat(fileName); err != nil {
mlog.Debug("We could not find the license key in the database or on disk at", mlog.String("filename", fileName))
return nil, nil
return nil, nil, fmt.Errorf("We could not find the license key on disk at %s: %w", fileName, err)
}

mlog.Info("License key has not been uploaded. Loading license key from disk at", mlog.String("filename", fileName))
licenseBytes := GetLicenseFileFromDisk(fileName)

success, licenseStr := LicenseValidator.ValidateLicense(licenseBytes)
if !success {
mlog.Error("Found license key at %v but it appears to be invalid.", mlog.String("filename", fileName))
return nil, nil
licenseStr, err := LicenseValidator.ValidateLicense(licenseBytes)
if err != nil {
return nil, nil, fmt.Errorf("Found license key at %s but it appears to be invalid: %w", fileName, err)
}

var license model.License
if jsonErr := json.Unmarshal([]byte(licenseStr), &license); jsonErr != nil {
mlog.Error("Failed to decode license from JSON", mlog.Err(jsonErr))
return nil, nil
return nil, nil, fmt.Errorf("Found license key at %s but it appears to be invalid: %w", fileName, err)
}

return &license, licenseBytes
return &license, licenseBytes, nil
}

func GetLicenseFileFromDisk(fileName string) []byte {
Expand Down
32 changes: 16 additions & 16 deletions server/channels/utils/license_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ var validTestLicense = []byte("eyJpZCI6InpvZ3c2NW44Z2lmajVkbHJoYThtYnUxcGl3Iiwia
func TestValidateLicense(t *testing.T) {
t.Run("should fail with junk data", func(t *testing.T) {
b1 := []byte("junk")
ok, _ := LicenseValidator.ValidateLicense(b1)
require.False(t, ok, "should have failed - bad license")
_, err := LicenseValidator.ValidateLicense(b1)
require.Error(t, err, "should have failed - bad license")

b2 := []byte("junkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunkjunk")
ok, _ = LicenseValidator.ValidateLicense(b2)
require.False(t, ok, "should have failed - bad license")
_, err = LicenseValidator.ValidateLicense(b2)
require.Error(t, err, "should have failed - bad license")
})

t.Run("should not panic on shorter than expected input", func(t *testing.T) {
Expand All @@ -42,8 +42,8 @@ func TestValidateLicense(t *testing.T) {
err = encoder.Close()
require.NoError(t, err)

ok, str := LicenseValidator.ValidateLicense(licenseData.Bytes())
require.False(t, ok)
str, err := LicenseValidator.ValidateLicense(licenseData.Bytes())
require.Error(t, err)
require.Empty(t, str)
})

Expand All @@ -61,35 +61,35 @@ func TestValidateLicense(t *testing.T) {
err = encoder.Close()
require.NoError(t, err)

ok, str := LicenseValidator.ValidateLicense(licenseData.Bytes())
require.False(t, ok)
str, err := LicenseValidator.ValidateLicense(licenseData.Bytes())
require.Error(t, err)
require.Empty(t, str)
})

t.Run("should reject invalid license in test service environment", func(t *testing.T) {
os.Setenv("MM_SERVICEENVIRONMENT", model.ServiceEnvironmentTest)
defer os.Unsetenv("MM_SERVICEENVIRONMENT")

ok, str := LicenseValidator.ValidateLicense(nil)
require.False(t, ok)
str, err := LicenseValidator.ValidateLicense(nil)
require.Error(t, err)
require.Empty(t, str)
})

t.Run("should validate valid test license in test service environment", func(t *testing.T) {
os.Setenv("MM_SERVICEENVIRONMENT", model.ServiceEnvironmentTest)
defer os.Unsetenv("MM_SERVICEENVIRONMENT")

ok, str := LicenseValidator.ValidateLicense(validTestLicense)
require.True(t, ok)
str, err := LicenseValidator.ValidateLicense(validTestLicense)
require.NoError(t, err)
require.NotEmpty(t, str)
})

t.Run("should reject valid test license in production service environment", func(t *testing.T) {
os.Setenv("MM_SERVICEENVIRONMENT", model.ServiceEnvironmentProduction)
defer os.Unsetenv("MM_SERVICEENVIRONMENT")

ok, str := LicenseValidator.ValidateLicense(validTestLicense)
require.False(t, ok)
str, err := LicenseValidator.ValidateLicense(validTestLicense)
require.Error(t, err)
require.Empty(t, str)
})
}
Expand Down Expand Up @@ -117,7 +117,7 @@ func TestGetLicenseFileFromDisk(t *testing.T) {
fileBytes := GetLicenseFileFromDisk(f.Name())
require.NotEmpty(t, fileBytes, "should have read the file")

success, _ := LicenseValidator.ValidateLicense(fileBytes)
assert.False(t, success, "should have been an invalid file")
_, err = LicenseValidator.ValidateLicense(fileBytes)
assert.Error(t, err, "should have been an invalid file")
})
}
Loading

0 comments on commit 71e26b8

Please sign in to comment.