diff --git a/migrations/tables/k0s_tokens.yaml b/migrations/tables/k0s_tokens.yaml new file mode 100644 index 0000000000..1ae6760972 --- /dev/null +++ b/migrations/tables/k0s_tokens.yaml @@ -0,0 +1,21 @@ +apiVersion: schemas.schemahero.io/v1alpha4 +kind: Table +metadata: + name: k0s-tokens +spec: + name: k0s_tokens + requires: [] + schema: + rqlite: + strict: true + primaryKey: + - token + columns: + - name: token + type: text + constraints: + notNull: true + - name: roles + type: text + constraints: + notNull: true diff --git a/pkg/handlers/handlers.go b/pkg/handlers/handlers.go index 1ddde1d543..8ee55ea376 100644 --- a/pkg/handlers/handlers.go +++ b/pkg/handlers/handlers.go @@ -277,8 +277,8 @@ func RegisterSessionAuthRoutes(r *mux.Router, kotsStore store.Store, handler KOT // HelmVM r.Name("HelmVM").Path("/api/v1/helmvm").HandlerFunc(NotImplemented) - r.Name("GenerateHelmVMNodeJoinCommand").Path("/api/v1/helmvm/generate-node-join-command").Methods("POST"). - HandlerFunc(middleware.EnforceAccess(policy.ClusterWrite, handler.GenerateHelmVMNodeJoinCommand)) + r.Name("GenerateK0sNodeJoinCommand").Path("/api/v1/helmvm/generate-node-join-command").Methods("POST"). + HandlerFunc(middleware.EnforceAccess(policy.ClusterWrite, handler.GenerateK0sNodeJoinCommand)) r.Name("DrainHelmVMNode").Path("/api/v1/helmvm/nodes/{nodeName}/drain").Methods("POST"). HandlerFunc(middleware.EnforceAccess(policy.ClusterWrite, handler.DrainHelmVMNode)) r.Name("DeleteHelmVMNode").Path("/api/v1/helmvm/nodes/{nodeName}").Methods("DELETE"). @@ -355,6 +355,9 @@ func RegisterUnauthenticatedRoutes(handler *Handler, kotsStore store.Store, debu // These handlers should be called by the application only. loggingRouter.Path("/license/v1/license").Methods("GET").HandlerFunc(handler.GetPlatformLicenseCompatibility) loggingRouter.Path("/api/v1/app/custom-metrics").Methods("POST").HandlerFunc(handler.GetSendCustomAppMetricsHandler(kotsStore)) + + // This handler requires a valid token in the query + loggingRouter.Path("/api/v1/embedded-cluster/join").Methods("GET").HandlerFunc(handler.GetK0sNodeJoinCommand) } func RegisterLicenseIDAuthRoutes(r *mux.Router, kotsStore store.Store, handler KOTSHandler) { diff --git a/pkg/handlers/helmvm_node_join_command.go b/pkg/handlers/helmvm_node_join_command.go index 6b8abe4654..b4d6a0da4f 100644 --- a/pkg/handlers/helmvm_node_join_command.go +++ b/pkg/handlers/helmvm_node_join_command.go @@ -2,54 +2,115 @@ package handlers import ( "encoding/json" + "fmt" "net/http" - "time" "github.com/replicatedhq/kots/pkg/helmvm" "github.com/replicatedhq/kots/pkg/k8sutil" "github.com/replicatedhq/kots/pkg/logger" + "github.com/replicatedhq/kots/pkg/store/kotsstore" ) -type GenerateHelmVMNodeJoinCommandResponse struct { +type GenerateK0sNodeJoinCommandResponse struct { Command []string `json:"command"` - Expiry string `json:"expiry"` +} + +type GetK0sNodeJoinCommandResponse struct { + ClusterID string `json:"clusterID"` + K0sJoinCommand string `json:"k0sJoinCommand"` + K0sToken string `json:"k0sToken"` } type GenerateHelmVMNodeJoinCommandRequest struct { Roles []string `json:"roles"` } -func (h *Handler) GenerateHelmVMNodeJoinCommand(w http.ResponseWriter, r *http.Request) { +func (h *Handler) GenerateK0sNodeJoinCommand(w http.ResponseWriter, r *http.Request) { generateHelmVMNodeJoinCommandRequest := GenerateHelmVMNodeJoinCommandRequest{} if err := json.NewDecoder(r.Body).Decode(&generateHelmVMNodeJoinCommandRequest); err != nil { - logger.Error(err) + logger.Error(fmt.Errorf("failed to decode request body: %w", err)) w.WriteHeader(http.StatusBadRequest) return } + store := kotsstore.StoreFromEnv() + token, err := store.SetK0sInstallCommandRoles(generateHelmVMNodeJoinCommandRequest.Roles) + if err != nil { + logger.Error(fmt.Errorf("failed to set k0s install command roles: %w", err)) + w.WriteHeader(http.StatusInternalServerError) + return + } + + client, err := k8sutil.GetClientset() + if err != nil { + logger.Error(fmt.Errorf("failed to get clientset: %w", err)) + w.WriteHeader(http.StatusInternalServerError) + return + } + nodeJoinCommand, err := helmvm.GenerateAddNodeCommand(r.Context(), client, token) + if err != nil { + logger.Error(fmt.Errorf("failed to generate add node command: %w", err)) + w.WriteHeader(http.StatusInternalServerError) + return + } + + JSON(w, http.StatusOK, GenerateK0sNodeJoinCommandResponse{ + Command: []string{nodeJoinCommand}, + }) +} + +// this function relies on the token being valid for authentication +func (h *Handler) GetK0sNodeJoinCommand(w http.ResponseWriter, r *http.Request) { + // read query string, ensure that the token is valid + token := r.URL.Query().Get("token") + store := kotsstore.StoreFromEnv() + roles, err := store.GetK0sInstallCommandRoles(token) + if err != nil { + logger.Error(fmt.Errorf("failed to get k0s install command roles: %w", err)) + w.WriteHeader(http.StatusInternalServerError) + return + } + + // use roles to generate join token etc client, err := k8sutil.GetClientset() if err != nil { - logger.Error(err) + logger.Error(fmt.Errorf("failed to get clientset: %w", err)) w.WriteHeader(http.StatusInternalServerError) return } k0sRole := "worker" - for _, role := range generateHelmVMNodeJoinCommandRequest.Roles { + for _, role := range roles { if role == "controller" { k0sRole = "controller" break } } - command, expiry, err := helmvm.GenerateAddNodeCommand(r.Context(), client, k0sRole) + k0sToken, err := helmvm.GenerateAddNodeToken(r.Context(), client, k0sRole) + if err != nil { + logger.Error(fmt.Errorf("failed to generate add node token: %w", err)) + w.WriteHeader(http.StatusInternalServerError) + return + } + + k0sJoinCommand, err := helmvm.GenerateK0sJoinCommand(r.Context(), client, roles) if err != nil { - logger.Error(err) + logger.Error(fmt.Errorf("failed to generate k0s join command: %w", err)) w.WriteHeader(http.StatusInternalServerError) return } - JSON(w, http.StatusOK, GenerateHelmVMNodeJoinCommandResponse{ - Command: command, - Expiry: expiry.Format(time.RFC3339), + + clusterID, err := helmvm.ClusterID(client) + if err != nil { + logger.Error(fmt.Errorf("failed to get cluster id: %w", err)) + w.WriteHeader(http.StatusInternalServerError) + return + } + + JSON(w, http.StatusOK, GetK0sNodeJoinCommandResponse{ + ClusterID: clusterID, + K0sJoinCommand: k0sJoinCommand, + K0sToken: k0sToken, }) } diff --git a/pkg/handlers/interface.go b/pkg/handlers/interface.go index d98dd07e18..0a1fa800fe 100644 --- a/pkg/handlers/interface.go +++ b/pkg/handlers/interface.go @@ -139,7 +139,7 @@ type KOTSHandler interface { GetKurlNodes(w http.ResponseWriter, r *http.Request) // HelmVM - GenerateHelmVMNodeJoinCommand(w http.ResponseWriter, r *http.Request) + GenerateK0sNodeJoinCommand(w http.ResponseWriter, r *http.Request) DrainHelmVMNode(w http.ResponseWriter, r *http.Request) DeleteHelmVMNode(w http.ResponseWriter, r *http.Request) GetHelmVMNodes(w http.ResponseWriter, r *http.Request) diff --git a/pkg/handlers/mock/mock.go b/pkg/handlers/mock/mock.go index c186b8a9ac..0898775f70 100644 --- a/pkg/handlers/mock/mock.go +++ b/pkg/handlers/mock/mock.go @@ -443,15 +443,15 @@ func (mr *MockKOTSHandlerMockRecorder) GarbageCollectImages(w, r interface{}) *g } // GenerateHelmVMNodeJoinCommand mocks base method. -func (m *MockKOTSHandler) GenerateHelmVMNodeJoinCommand(w http.ResponseWriter, r *http.Request) { +func (m *MockKOTSHandler) GenerateK0sNodeJoinCommand(w http.ResponseWriter, r *http.Request) { m.ctrl.T.Helper() - m.ctrl.Call(m, "GenerateHelmVMNodeJoinCommand", w, r) + m.ctrl.Call(m, "GenerateK0sNodeJoinCommand", w, r) } // GenerateHelmVMNodeJoinCommand indicates an expected call of GenerateHelmVMNodeJoinCommand. func (mr *MockKOTSHandlerMockRecorder) GenerateHelmVMNodeJoinCommand(w, r interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenerateHelmVMNodeJoinCommand", reflect.TypeOf((*MockKOTSHandler)(nil).GenerateHelmVMNodeJoinCommand), w, r) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenerateK0sNodeJoinCommand", reflect.TypeOf((*MockKOTSHandler)(nil).GenerateK0sNodeJoinCommand), w, r) } // GenerateKurlNodeJoinCommandMaster mocks base method. diff --git a/pkg/helmvm/node_join.go b/pkg/helmvm/node_join.go index 4bbf1e197c..aa6544f044 100644 --- a/pkg/helmvm/node_join.go +++ b/pkg/helmvm/node_join.go @@ -3,63 +3,56 @@ package helmvm import ( "context" "fmt" + "os" "strings" "sync" "time" - "github.com/google/uuid" corev1 "k8s.io/api/core/v1" kuberneteserrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/kubernetes" ) -type joinCommandEntry struct { - Command []string +type joinTokenEntry struct { + Token string Creation *time.Time Mut sync.Mutex } -var joinCommandMapMut = sync.Mutex{} -var joinCommandMap = map[string]*joinCommandEntry{} +var joinTokenMapMut = sync.Mutex{} +var joinTokenMap = map[string]*joinTokenEntry{} -// GenerateAddNodeCommand will generate the HelmVM node add command for a primary or secondary node +// GenerateAddNodeToken will generate the HelmVM node add command for a primary or secondary node // join commands will last for 24 hours, and will be cached for 1 hour after first generation -func GenerateAddNodeCommand(ctx context.Context, client kubernetes.Interface, nodeRole string) ([]string, *time.Time, error) { - // get the joinCommand struct entry for this node role - joinCommandMapMut.Lock() - if _, ok := joinCommandMap[nodeRole]; !ok { - joinCommandMap[nodeRole] = &joinCommandEntry{} +func GenerateAddNodeToken(ctx context.Context, client kubernetes.Interface, nodeRole string) (string, error) { + // get the joinToken struct entry for this node role + joinTokenMapMut.Lock() + if _, ok := joinTokenMap[nodeRole]; !ok { + joinTokenMap[nodeRole] = &joinTokenEntry{} } - joinCommand := joinCommandMap[nodeRole] - joinCommandMapMut.Unlock() + joinToken := joinTokenMap[nodeRole] + joinTokenMapMut.Unlock() - // lock the joinCommand struct entry - joinCommand.Mut.Lock() - defer joinCommand.Mut.Unlock() + // lock the joinToken struct entry + joinToken.Mut.Lock() + defer joinToken.Mut.Unlock() - // if the joinCommand has been generated in the past hour, return it - if joinCommand.Creation != nil && time.Now().Before(joinCommand.Creation.Add(time.Hour)) { - expiry := joinCommand.Creation.Add(time.Hour * 24) - return joinCommand.Command, &expiry, nil + // if the joinToken has been generated in the past hour, return it + if joinToken.Creation != nil && time.Now().Before(joinToken.Creation.Add(time.Hour)) { + return joinToken.Token, nil } newToken, err := runAddNodeCommandPod(ctx, client, nodeRole) if err != nil { - return nil, nil, fmt.Errorf("failed to run add node command pod: %w", err) - } - - newCmd, err := generateAddNodeCommand(ctx, client, nodeRole, newToken) - if err != nil { - return nil, nil, fmt.Errorf("failed to generate add node command: %w", err) + return "", fmt.Errorf("failed to run add node command pod: %w", err) } now := time.Now() - joinCommand.Command = newCmd - joinCommand.Creation = &now + joinToken.Token = newToken + joinToken.Creation = &now - expiry := now.Add(time.Hour * 24) - return newCmd, &expiry, nil + return newToken, nil } // run a pod that will generate the add node token @@ -213,32 +206,87 @@ func runAddNodeCommandPod(ctx context.Context, client kubernetes.Interface, node return string(podLogs), nil } -// generate the add node command from the join token, the node roles, and info from the embedded-cluster-config configmap -func generateAddNodeCommand(ctx context.Context, client kubernetes.Interface, nodeRole string, token string) ([]string, error) { +// GenerateAddNodeCommand returns the command a user should run to add a node with the provided token +// the command will be of the form 'helmvm node join ip:port UUID' +func GenerateAddNodeCommand(ctx context.Context, client kubernetes.Interface, token string) (string, error) { cm, err := ReadConfigMap(client) if err != nil { - return nil, fmt.Errorf("failed to read configmap: %w", err) + return "", fmt.Errorf("failed to read configmap: %w", err) } - clusterID := cm.Data["embedded-cluster-id"] binaryName := cm.Data["embedded-binary-name"] - clusterUUID := uuid.UUID{} - err = clusterUUID.UnmarshalText([]byte(clusterID)) + // get the IP of a controller node + nodeIP, err := getControllerNodeIP(ctx, client) + if err != nil { + return "", fmt.Errorf("failed to get controller node IP: %w", err) + } + + // get the port of the 'admin-console' service + port, err := getAdminConsolePort(ctx, client) + if err != nil { + return "", fmt.Errorf("failed to get admin console port: %w", err) + } + + return fmt.Sprintf("%s node join %s:%d %s", binaryName, nodeIP, port, token), nil +} + +// GenerateK0sJoinCommand returns the k0s node join command, without the token but with all other required flags +// (including node labels generated from the roles etc) +func GenerateK0sJoinCommand(ctx context.Context, client kubernetes.Interface, roles []string) (string, error) { + k0sRole := "worker" + for _, role := range roles { + if role == "controller" { + k0sRole = "controller" + } + } + + cmd := []string{"/usr/local/bin/k0s", "install", k0sRole, "--force"} + if k0sRole == "controller" { + cmd = append(cmd, "--enable-worker") + } + + return strings.Join(cmd, " "), nil +} + +// gets the port of the 'admin-console' service +func getAdminConsolePort(ctx context.Context, client kubernetes.Interface) (int32, error) { + svc, err := client.CoreV1().Services(os.Getenv("POD_NAMESPACE")).Get(ctx, "admin-console", metav1.GetOptions{}) if err != nil { - return nil, fmt.Errorf("failed to unmarshal cluster id %s: %w", clusterID, err) + return -1, fmt.Errorf("failed to get admin-console service: %w", err) } - fullToken := joinToken{ - ClusterID: clusterUUID, - Token: token, - Role: nodeRole, + for _, port := range svc.Spec.Ports { + if port.Name == "http" { + return port.NodePort, nil + } } + return -1, fmt.Errorf("did not find port 'http' in service 'admin-console'") +} - b64token, err := fullToken.Encode() +// getControllerNodeIP gets the IP of a healthy controller node +func getControllerNodeIP(ctx context.Context, client kubernetes.Interface) (string, error) { + nodes, err := client.CoreV1().Nodes().List(ctx, metav1.ListOptions{}) if err != nil { - return nil, fmt.Errorf("unable to encode token: %w", err) + return "", fmt.Errorf("failed to list nodes: %w", err) + } + + for _, node := range nodes.Items { + if cp, ok := node.Labels["node-role.kubernetes.io/control-plane"]; !ok || cp != "true" { + continue + } + + for _, condition := range node.Status.Conditions { + if condition.Type == "Ready" && condition.Status == "True" { + for _, address := range node.Status.Addresses { + if address.Type == "InternalIP" { + return address.Address, nil + } + } + } + } + } - return []string{binaryName + " node join", b64token}, nil + return "", fmt.Errorf("failed to find healthy controller node") } diff --git a/pkg/helmvm/util.go b/pkg/helmvm/util.go index ce358abab0..65d93f0c9f 100644 --- a/pkg/helmvm/util.go +++ b/pkg/helmvm/util.go @@ -43,3 +43,12 @@ func IsHelmVM(clientset kubernetes.Interface) (bool, error) { func IsHA(clientset kubernetes.Interface) (bool, error) { return true, nil } + +func ClusterID(client kubernetes.Interface) (string, error) { + configMap, err := ReadConfigMap(client) + if err != nil { + return "", fmt.Errorf("failed to read configmap: %w", err) + } + + return configMap.Data["embedded-cluster-id"], nil +} diff --git a/pkg/store/kotsstore/k0s_store.go b/pkg/store/kotsstore/k0s_store.go new file mode 100644 index 0000000000..61c717dd0b --- /dev/null +++ b/pkg/store/kotsstore/k0s_store.go @@ -0,0 +1,68 @@ +package kotsstore + +import ( + "encoding/json" + "fmt" + "github.com/google/uuid" + "github.com/replicatedhq/kots/pkg/persistence" + "github.com/rqlite/gorqlite" +) + +func (s *KOTSStore) SetK0sInstallCommandRoles(roles []string) (string, error) { + db := persistence.MustGetDBSession() + + installID := uuid.New().String() + + query := `delete from k0s_tokens where token = ?` + wr, err := db.WriteOneParameterized(gorqlite.ParameterizedStatement{ + Query: query, + Arguments: []interface{}{installID}, + }) + if err != nil { + return "", fmt.Errorf("delete k0s join token: %v: %v", err, wr.Err) + } + + jsonRoles, err := json.Marshal(roles) + if err != nil { + return "", fmt.Errorf("failed to marshal roles: %w", err) + } + + query = `insert into k0s_tokens (token, roles) values (?, ?)` + wr, err = db.WriteOneParameterized(gorqlite.ParameterizedStatement{ + Query: query, + Arguments: []interface{}{installID, string(jsonRoles)}, + }) + if err != nil { + return "", fmt.Errorf("insert k0s join token: %v: %v", err, wr.Err) + } + + return installID, nil +} + +func (s *KOTSStore) GetK0sInstallCommandRoles(token string) ([]string, error) { + db := persistence.MustGetDBSession() + query := `select roles from k0s_tokens where token = ?` + rows, err := db.QueryOneParameterized(gorqlite.ParameterizedStatement{ + Query: query, + Arguments: []interface{}{token}, + }) + if err != nil { + return nil, fmt.Errorf("failed to query: %v: %v", err, rows.Err) + } + if !rows.Next() { + return nil, ErrNotFound + } + + rolesStr := "" + if err = rows.Scan(&rolesStr); err != nil { + return nil, fmt.Errorf("failed to scan roles: %w", err) + } + + rolesArr := []string{} + err = json.Unmarshal([]byte(rolesStr), &rolesArr) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal roles: %w", err) + } + + return rolesArr, nil +}