From 8129b2c1f6558443d34a84f0b6032bdb936970e3 Mon Sep 17 00:00:00 2001 From: Jack Shaw Date: Thu, 21 Nov 2024 15:57:57 +0000 Subject: [PATCH] refactor: ports schema to hold RI with charm_relation At the time, remove the unit_endpoint table. This table wasn't providing us with anything but pain. Instead, attach the foriegn keys (unit and charm_relation) directly to the port range This then required a fairly substantial refactor of the state layer. Introducing RI from ports to relation resulted in an additional complexity with the wildcard endpoint, as this does not have a corresponding charm relation. I have dealt with this by using the null value. So a null value for relation_uuid represents the wildcard endpoint. This turned out to be a fairly natural solution. This refactor also meant that the watcher now emits unit uuids, rather than unit_endpoint uuids. This meant the watcher code needed a refactor/simplification as well. --- apiserver/facades/client/client/api_test.go | 1 + .../facades/client/client/filtering_test.go | 8 +- .../facades/client/client/status_test.go | 4 +- domain/application/state/application.go | 15 +- domain/application/state/application_test.go | 19 +- domain/port/errors/errors.go | 4 + domain/port/service/package_mock_test.go | 48 ++--- domain/port/service/watcher.go | 28 +-- domain/port/state/state.go | 170 +++++++----------- domain/port/state/state_test.go | 32 +++- domain/port/state/types.go | 22 +-- domain/port/state/watcher.go | 41 ++--- domain/port/state/watcher_test.go | 70 +------- domain/port/watcher_test.go | 76 ++++---- domain/schema/model.go | 3 +- domain/schema/model/sql/0021-port-ranges.sql | 50 +++--- .../schema/model/triggers/port-range.gen.go | 7 +- domain/schema/schema_test.go | 2 +- .../charm-repo/quantal/mysql/metadata.yaml | 2 + 19 files changed, 273 insertions(+), 329 deletions(-) diff --git a/apiserver/facades/client/client/api_test.go b/apiserver/facades/client/client/api_test.go index e153cf2867b..de55d490fbb 100644 --- a/apiserver/facades/client/client/api_test.go +++ b/apiserver/facades/client/client/api_test.go @@ -228,6 +228,7 @@ var scenarioStatus = ¶ms.FullStatus{ "": network.AlphaSpaceName, "server": network.AlphaSpaceName, "server-admin": network.AlphaSpaceName, + "db": network.AlphaSpaceName, "db-router": network.AlphaSpaceName, "metrics-client": network.AlphaSpaceName, }, diff --git a/apiserver/facades/client/client/filtering_test.go b/apiserver/facades/client/client/filtering_test.go index 34ef2df4604..d2d701df294 100644 --- a/apiserver/facades/client/client/filtering_test.go +++ b/apiserver/facades/client/client/filtering_test.go @@ -240,14 +240,14 @@ func (s *filteringStatusSuite) TestFilterByPortRange(c *gc.C) { portService := s.ControllerDomainServices(c).Port() err = portService.UpdateUnitPorts(context.Background(), unit0UUID, network.GroupedPortRanges{ - "": []network.PortRange{network.MustParsePortRange("1000/tcp")}, - "foo": []network.PortRange{network.MustParsePortRange("2000/tcp")}, + "": []network.PortRange{network.MustParsePortRange("1000/tcp")}, + "db": []network.PortRange{network.MustParsePortRange("2000/tcp")}, }, network.GroupedPortRanges{}) c.Assert(err, jc.ErrorIsNil) err = portService.UpdateUnitPorts(context.Background(), unit1UUID, network.GroupedPortRanges{ - "": []network.PortRange{network.MustParsePortRange("2000/tcp")}, - "bar": []network.PortRange{network.MustParsePortRange("3000/tcp")}, + "": []network.PortRange{network.MustParsePortRange("2000/tcp")}, + "cache": []network.PortRange{network.MustParsePortRange("3000/tcp")}, }, network.GroupedPortRanges{}) c.Assert(err, jc.ErrorIsNil) diff --git a/apiserver/facades/client/client/status_test.go b/apiserver/facades/client/client/status_test.go index feff038a87f..4bdc9f79165 100644 --- a/apiserver/facades/client/client/status_test.go +++ b/apiserver/facades/client/client/status_test.go @@ -851,8 +851,8 @@ func (s *statusUnitTestSuite) TestUnitsWithOpenedPortsSent(c *gc.C) { portService := s.ControllerDomainServices(c).Port() err = portService.UpdateUnitPorts(context.Background(), unitUUID, network.GroupedPortRanges{ - "": []network.PortRange{network.MustParsePortRange("1000/tcp")}, - "foo": []network.PortRange{network.MustParsePortRange("2000/tcp")}, + "": []network.PortRange{network.MustParsePortRange("1000/tcp")}, + "db": []network.PortRange{network.MustParsePortRange("2000/tcp")}, }, network.GroupedPortRanges{}) c.Assert(err, jc.ErrorIsNil) diff --git a/domain/application/state/application.go b/domain/application/state/application.go index 6a1bda19569..46b96d5a296 100644 --- a/domain/application/state/application.go +++ b/domain/application/state/application.go @@ -1023,10 +1023,7 @@ func (st *ApplicationState) deletePorts(ctx context.Context, tx *sqlair.TX, unit deletePortRange := ` DELETE FROM port_range -WHERE unit_endpoint_uuid IN ( - SELECT uuid FROM unit_endpoint ue - WHERE ue.unit_uuid = $minimalUnit.uuid -) +WHERE unit_uuid = $minimalUnit.uuid ` deletePortRangeStmt, err := st.Prepare(deletePortRange, unit) if err != nil { @@ -1037,15 +1034,6 @@ WHERE unit_endpoint_uuid IN ( return errors.Annotate(err, "cannot delete port range records") } - deleteEndpoint := `DELETE FROM unit_endpoint WHERE unit_uuid = $minimalUnit.uuid` - deleteEndpointStmt, err := st.Prepare(deleteEndpoint, unit) - if err != nil { - return errors.Annotate(err, "cannot delete endpoint records") - } - - if err := tx.Query(ctx, deleteEndpointStmt, unit).Run(); err != nil { - return errors.Trace(err) - } return nil } @@ -1063,7 +1051,6 @@ func (st *ApplicationState) deleteSimpleUnitReferences(ctx context.Context, tx * "unit_workload_status", "cloud_container_status_data", "cloud_container_status", - "unit_endpoint", } { deleteUnitReference := fmt.Sprintf(`DELETE FROM %s WHERE unit_uuid = $minimalUnit.uuid`, table) deleteUnitReferenceStmt, err := st.Prepare(deleteUnitReference, unit) diff --git a/domain/application/state/application_test.go b/domain/application/state/application_test.go index b388071e044..44df07091d2 100644 --- a/domain/application/state/application_test.go +++ b/domain/application/state/application_test.go @@ -680,7 +680,6 @@ func (s *applicationStateSuite) TestDeleteUnit(c *gc.C) { deviceCount int addressCount int portCount int - endpointCount int agentStatusCount int agentStatusDataCount int workloadStatusCount int @@ -704,9 +703,6 @@ func (s *applicationStateSuite) TestDeleteUnit(c *gc.C) { if err := tx.QueryRowContext(ctx, "SELECT count(*) FROM cloud_container_port WHERE cloud_container_uuid=?", netNodeUUID).Scan(&portCount); err != nil { return err } - if err := tx.QueryRowContext(ctx, "SELECT count(*) FROM unit_endpoint WHERE unit_uuid=?", unitUUID).Scan(&endpointCount); err != nil { - return err - } if err := tx.QueryRowContext(ctx, "SELECT count(*) FROM unit_agent_status WHERE unit_uuid=?", unitUUID).Scan(&agentStatusCount); err != nil { return err } @@ -730,7 +726,6 @@ func (s *applicationStateSuite) TestDeleteUnit(c *gc.C) { c.Assert(err, jc.ErrorIsNil) c.Assert(addressCount, gc.Equals, 0) c.Assert(portCount, gc.Equals, 0) - c.Assert(endpointCount, gc.Equals, 0) c.Assert(deviceCount, gc.Equals, 0) c.Assert(containerCount, gc.Equals, 0) c.Assert(agentStatusCount, gc.Equals, 0) @@ -1945,6 +1940,20 @@ func (s *applicationStateSuite) createApplication(c *gc.C, name string, l life.L Charm: charm.Charm{ Metadata: charm.Metadata{ Name: name, + Provides: map[string]charm.Relation{ + "endpoint": { + Name: "endpoint", + Key: "endpoint", + Role: charm.RoleProvider, + Scope: charm.ScopeGlobal, + }, + "misc": { + Name: "misc", + Key: "misc", + Role: charm.RoleProvider, + Scope: charm.ScopeGlobal, + }, + }, }, }, Origin: charm.CharmOrigin{ diff --git a/domain/port/errors/errors.go b/domain/port/errors/errors.go index ee1c60e7d3e..7e9af33b4bb 100644 --- a/domain/port/errors/errors.go +++ b/domain/port/errors/errors.go @@ -15,4 +15,8 @@ const ( // PortRangeConflict describes an error that occurs when a user tries to open // or close a port range overlaps with another. PortRangeConflict = errors.ConstError("port range conflict") + + // InvalidEndpoint describes an error that occurs when a user trying to open + // or close a port range with an endpoint which does not exist on the unit. + InvalidEndpoint = errors.ConstError("invalid endpoint(s)") ) diff --git a/domain/port/service/package_mock_test.go b/domain/port/service/package_mock_test.go index 853c4a9e804..5c6e04e63ba 100644 --- a/domain/port/service/package_mock_test.go +++ b/domain/port/service/package_mock_test.go @@ -46,41 +46,41 @@ func (m *MockState) EXPECT() *MockStateMockRecorder { return m.recorder } -// FilterEndpointsForApplication mocks base method. -func (m *MockState) FilterEndpointsForApplication(arg0 context.Context, arg1 []string, arg2 application.ID) (set.Strings, error) { +// FilterUnitUUIDsForApplication mocks base method. +func (m *MockState) FilterUnitUUIDsForApplication(arg0 context.Context, arg1 []unit.UUID, arg2 application.ID) (set.Strings, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "FilterEndpointsForApplication", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "FilterUnitUUIDsForApplication", arg0, arg1, arg2) ret0, _ := ret[0].(set.Strings) ret1, _ := ret[1].(error) return ret0, ret1 } -// FilterEndpointsForApplication indicates an expected call of FilterEndpointsForApplication. -func (mr *MockStateMockRecorder) FilterEndpointsForApplication(arg0, arg1, arg2 any) *MockStateFilterEndpointsForApplicationCall { +// FilterUnitUUIDsForApplication indicates an expected call of FilterUnitUUIDsForApplication. +func (mr *MockStateMockRecorder) FilterUnitUUIDsForApplication(arg0, arg1, arg2 any) *MockStateFilterUnitUUIDsForApplicationCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterEndpointsForApplication", reflect.TypeOf((*MockState)(nil).FilterEndpointsForApplication), arg0, arg1, arg2) - return &MockStateFilterEndpointsForApplicationCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterUnitUUIDsForApplication", reflect.TypeOf((*MockState)(nil).FilterUnitUUIDsForApplication), arg0, arg1, arg2) + return &MockStateFilterUnitUUIDsForApplicationCall{Call: call} } -// MockStateFilterEndpointsForApplicationCall wrap *gomock.Call -type MockStateFilterEndpointsForApplicationCall struct { +// MockStateFilterUnitUUIDsForApplicationCall wrap *gomock.Call +type MockStateFilterUnitUUIDsForApplicationCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MockStateFilterEndpointsForApplicationCall) Return(arg0 set.Strings, arg1 error) *MockStateFilterEndpointsForApplicationCall { +func (c *MockStateFilterUnitUUIDsForApplicationCall) Return(arg0 set.Strings, arg1 error) *MockStateFilterUnitUUIDsForApplicationCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *MockStateFilterEndpointsForApplicationCall) Do(f func(context.Context, []string, application.ID) (set.Strings, error)) *MockStateFilterEndpointsForApplicationCall { +func (c *MockStateFilterUnitUUIDsForApplicationCall) Do(f func(context.Context, []unit.UUID, application.ID) (set.Strings, error)) *MockStateFilterUnitUUIDsForApplicationCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStateFilterEndpointsForApplicationCall) DoAndReturn(f func(context.Context, []string, application.ID) (set.Strings, error)) *MockStateFilterEndpointsForApplicationCall { +func (c *MockStateFilterUnitUUIDsForApplicationCall) DoAndReturn(f func(context.Context, []unit.UUID, application.ID) (set.Strings, error)) *MockStateFilterUnitUUIDsForApplicationCall { c.Call = c.Call.DoAndReturn(f) return c } @@ -280,41 +280,41 @@ func (c *MockStateGetEndpointsCall) DoAndReturn(f func(domain.AtomicContext, uni return c } -// GetMachineNamesForUnitEndpoints mocks base method. -func (m *MockState) GetMachineNamesForUnitEndpoints(arg0 context.Context, arg1 []string) ([]machine.Name, error) { +// GetMachineNamesForUnits mocks base method. +func (m *MockState) GetMachineNamesForUnits(arg0 context.Context, arg1 []unit.UUID) ([]machine.Name, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetMachineNamesForUnitEndpoints", arg0, arg1) + ret := m.ctrl.Call(m, "GetMachineNamesForUnits", arg0, arg1) ret0, _ := ret[0].([]machine.Name) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetMachineNamesForUnitEndpoints indicates an expected call of GetMachineNamesForUnitEndpoints. -func (mr *MockStateMockRecorder) GetMachineNamesForUnitEndpoints(arg0, arg1 any) *MockStateGetMachineNamesForUnitEndpointsCall { +// GetMachineNamesForUnits indicates an expected call of GetMachineNamesForUnits. +func (mr *MockStateMockRecorder) GetMachineNamesForUnits(arg0, arg1 any) *MockStateGetMachineNamesForUnitsCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMachineNamesForUnitEndpoints", reflect.TypeOf((*MockState)(nil).GetMachineNamesForUnitEndpoints), arg0, arg1) - return &MockStateGetMachineNamesForUnitEndpointsCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMachineNamesForUnits", reflect.TypeOf((*MockState)(nil).GetMachineNamesForUnits), arg0, arg1) + return &MockStateGetMachineNamesForUnitsCall{Call: call} } -// MockStateGetMachineNamesForUnitEndpointsCall wrap *gomock.Call -type MockStateGetMachineNamesForUnitEndpointsCall struct { +// MockStateGetMachineNamesForUnitsCall wrap *gomock.Call +type MockStateGetMachineNamesForUnitsCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MockStateGetMachineNamesForUnitEndpointsCall) Return(arg0 []machine.Name, arg1 error) *MockStateGetMachineNamesForUnitEndpointsCall { +func (c *MockStateGetMachineNamesForUnitsCall) Return(arg0 []machine.Name, arg1 error) *MockStateGetMachineNamesForUnitsCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *MockStateGetMachineNamesForUnitEndpointsCall) Do(f func(context.Context, []string) ([]machine.Name, error)) *MockStateGetMachineNamesForUnitEndpointsCall { +func (c *MockStateGetMachineNamesForUnitsCall) Do(f func(context.Context, []unit.UUID) ([]machine.Name, error)) *MockStateGetMachineNamesForUnitsCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockStateGetMachineNamesForUnitEndpointsCall) DoAndReturn(f func(context.Context, []string) ([]machine.Name, error)) *MockStateGetMachineNamesForUnitEndpointsCall { +func (c *MockStateGetMachineNamesForUnitsCall) DoAndReturn(f func(context.Context, []unit.UUID) ([]machine.Name, error)) *MockStateGetMachineNamesForUnitsCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/domain/port/service/watcher.go b/domain/port/service/watcher.go index 11e52df52d5..6dd79c9f4f8 100644 --- a/domain/port/service/watcher.go +++ b/domain/port/service/watcher.go @@ -14,6 +14,7 @@ import ( "github.com/juju/juju/core/database" "github.com/juju/juju/core/logger" coremachine "github.com/juju/juju/core/machine" + "github.com/juju/juju/core/unit" "github.com/juju/juju/core/watcher" "github.com/juju/juju/core/watcher/eventsource" ) @@ -61,13 +62,13 @@ type WatcherState interface { // event for the WatchMachineOpenedPorts watcher InitialWatchMachineOpenedPortsStatement() string - // GetMachineNamesForUnitEndpoints returns map from endpoint uuids to the uuids of + // GetMachineNamesForUnits returns map from endpoint uuids to the uuids of // the machines which host that endpoint for each provided endpoint uuid. - GetMachineNamesForUnitEndpoints(ctx context.Context, endpointUUIDs []string) ([]coremachine.Name, error) + GetMachineNamesForUnits(context.Context, []unit.UUID) ([]coremachine.Name, error) // FilterEndpointsForApplication returns the subset of provided endpoint uuids // that are associated with the provided application. - FilterEndpointsForApplication(ctx context.Context, eps []string, app coreapplication.ID) (set.Strings, error) + FilterUnitUUIDsForApplication(context.Context, []unit.UUID, coreapplication.ID) (set.Strings, error) } // WatchMachineOpenedPorts returns a strings watcher for opened ports. This watcher @@ -101,11 +102,14 @@ func (s *WatchableService) endpointToMachineMapper( ctx context.Context, db database.TxnRunner, events []changestream.ChangeEvent, ) ([]changestream.ChangeEvent, error) { - endpointUUIDs := transform.Slice(events, func(e changestream.ChangeEvent) string { - return e.Changed() + unitUUIDs, err := transform.SliceOrErr(events, func(e changestream.ChangeEvent) (unit.UUID, error) { + return unit.ParseID(e.Changed()) }) + if err != nil { + return nil, err + } - machineNames, err := s.st.GetMachineNamesForUnitEndpoints(ctx, endpointUUIDs) + machineNames, err := s.st.GetMachineNamesForUnits(ctx, unitUUIDs) if err != nil { return nil, err } @@ -129,16 +133,20 @@ func (s *WatchableService) filterForApplication(applicationUUID coreapplication. return func( ctx context.Context, db database.TxnRunner, events []changestream.ChangeEvent, ) ([]changestream.ChangeEvent, error) { - endpointUUIDs := transform.Slice(events, func(e changestream.ChangeEvent) string { - return e.Changed() + unitUUIDs, err := transform.SliceOrErr(events, func(e changestream.ChangeEvent) (unit.UUID, error) { + return unit.ParseID(e.Changed()) }) - endpointUUIDsForApplication, err := s.st.FilterEndpointsForApplication(ctx, endpointUUIDs, applicationUUID) + if err != nil { + return nil, err + } + + unitUUIDsForApplication, err := s.st.FilterUnitUUIDsForApplication(ctx, unitUUIDs, applicationUUID) if err != nil { return nil, err } results := make([]changestream.ChangeEvent, 0, len(events)) for _, event := range events { - if endpointUUIDsForApplication.Contains(event.Changed()) { + if unitUUIDsForApplication.Contains(event.Changed()) { results = append(results, event) } } diff --git a/domain/port/state/state.go b/domain/port/state/state.go index 5c4b84d7a4b..776f31e7879 100644 --- a/domain/port/state/state.go +++ b/domain/port/state/state.go @@ -46,10 +46,8 @@ func (st *State) GetUnitOpenedPorts(ctx context.Context, unit coreunit.UUID) (ne query, err := st.Prepare(` SELECT &endpointPortRange.* -FROM port_range -JOIN protocol ON port_range.protocol_id = protocol.id -JOIN unit_endpoint ON port_range.unit_endpoint_uuid = unit_endpoint.uuid -WHERE unit_endpoint.unit_uuid = $unitUUID.unit_uuid +FROM v_port_range +WHERE unit_uuid = $unitUUID.unit_uuid `, endpointPortRange{}, unitUUID) if err != nil { return nil, errors.Errorf("preparing get unit opened ports statement: %w", err) @@ -95,10 +93,8 @@ func (s *State) GetAllOpenedPorts(ctx context.Context) (port.UnitGroupedPortRang query, err := s.Prepare(` SELECT DISTINCT &unitNamePortRange.* -FROM port_range -JOIN protocol ON port_range.protocol_id = protocol.id -JOIN unit_endpoint ON port_range.unit_endpoint_uuid = unit_endpoint.uuid -JOIN unit ON unit_endpoint.unit_uuid = unit.uuid +FROM v_port_range +JOIN unit ON unit_uuid = unit.uuid `, unitNamePortRange{}) if err != nil { return nil, errors.Errorf("preparing get all opened ports statement: %w", err) @@ -146,8 +142,9 @@ func (st *State) GetMachineOpenedPorts(ctx context.Context, machine string) (map query, err := st.Prepare(` SELECT &unitEndpointPortRange.* -FROM v_port_range AS pr -JOIN machine ON pr.net_node_uuid = machine.net_node_uuid +FROM v_port_range +JOIN unit ON unit_uuid = unit.uuid +JOIN machine ON unit.net_node_uuid = machine.net_node_uuid WHERE machine.uuid = $machineUUID.machine_uuid `, unitEndpointPortRange{}, machineUUID) if err != nil { @@ -186,6 +183,7 @@ func (st *State) GetApplicationOpenedPorts(ctx context.Context, application core query, err := st.Prepare(` SELECT &unitEndpointPortRange.* FROM v_port_range +JOIN unit ON unit_uuid = unit.uuid WHERE application_uuid = $applicationUUID.application_uuid `, unitEndpointPortRange{}, applicationUUID) if err != nil { @@ -218,10 +216,8 @@ func (st *State) GetColocatedOpenedPorts(ctx domain.AtomicContext, unit coreunit getOpenedPorts, err := st.Prepare(` SELECT &portRange.* -FROM port_range AS pr -JOIN protocol AS p ON pr.protocol_id = p.id -JOIN unit_endpoint AS ep ON pr.unit_endpoint_uuid = ep.uuid -JOIN unit AS u ON ep.unit_uuid = u.uuid +FROM v_port_range AS pr +JOIN unit AS u ON unit_uuid = u.uuid JOIN unit AS u2 on u2.net_node_uuid = u.net_node_uuid WHERE u2.uuid = $unitUUID.unit_uuid `, portRange{}, unitUUID) @@ -241,7 +237,7 @@ WHERE u2.uuid = $unitUUID.unit_uuid return nil, errors.Errorf("getting opened ports for colocated units with %q: %w", unit, err) } - ret := transform.Slice(portRanges, func(p portRange) network.PortRange { return p.decode() }) + ret := transform.Slice(portRanges, portRange.decode) network.SortPortRanges(ret) return ret, nil } @@ -254,11 +250,9 @@ func (st *State) GetEndpointOpenedPorts(ctx domain.AtomicContext, unit coreunit. query, err := st.Prepare(` SELECT &portRange.* -FROM port_range -JOIN protocol ON port_range.protocol_id = protocol.id -JOIN unit_endpoint ON port_range.unit_endpoint_uuid = unit_endpoint.uuid -WHERE unit_endpoint.unit_uuid = $unitUUID.unit_uuid -AND unit_endpoint.endpoint = $endpointName.endpoint +FROM v_port_range +WHERE unit_uuid = $unitUUID.unit_uuid +AND endpoint = $endpointName.endpoint `, portRange{}, unitUUID, endpointName) if err != nil { return nil, errors.Errorf("preparing get endpoint opened ports statement: %w", err) @@ -297,14 +291,15 @@ func (st *State) UpdateUnitPorts( for endpoint := range closePorts { endpointsUnderActionSet.Add(endpoint) } + endpointsUnderActionSet.Remove(port.WildcardEndpoint) endpointsUnderAction := endpoints(endpointsUnderActionSet.Values()) unitUUID := unitUUID{UUID: unit} return domain.Run(ctx, func(ctx context.Context, tx *sqlair.TX) error { - endpoints, err := st.ensureEndpoints(ctx, tx, unitUUID, endpointsUnderAction) + endpoints, err := st.lookupRelationUUIDs(ctx, tx, unitUUID, endpointsUnderAction) if err != nil { - return errors.Errorf("ensuring endpoints exist for unit %q: %w", unit, err) + return errors.Errorf("looking up relation endpoint uuids for unit %q: %w", unit, err) } currentUnitOpenedPorts, err := st.getUnitOpenedPorts(ctx, tx, unitUUID) @@ -312,7 +307,7 @@ func (st *State) UpdateUnitPorts( return errors.Errorf("getting opened ports for unit %q: %w", unit, err) } - err = st.openPorts(ctx, tx, openPorts, currentUnitOpenedPorts, endpoints) + err = st.openPorts(ctx, tx, openPorts, currentUnitOpenedPorts, unitUUID, endpoints) if err != nil { return errors.Errorf("opening ports for unit %q: %w", unit, err) } @@ -326,64 +321,34 @@ func (st *State) UpdateUnitPorts( }) } -// ensureEndpoints ensures that the given endpoints are present in the database. -// Return all endpoints under action with their corresponding UUIDs. -// -// TODO(jack-w-shaw): Once it has been implemented, we should verify new endpoints -// are valid by checking the charm_relation table. -func (st *State) ensureEndpoints( +func (st *State) lookupRelationUUIDs( ctx context.Context, tx *sqlair.TX, unitUUID unitUUID, endpointsUnderAction endpoints, ) ([]endpoint, error) { - getUnitEndpoints, err := st.Prepare(` + getEndpoints, err := st.Prepare(` SELECT &endpoint.* -FROM unit_endpoint +FROM v_endpoint WHERE unit_uuid = $unitUUID.unit_uuid AND endpoint IN ($endpoints[:]) `, endpoint{}, unitUUID, endpointsUnderAction) if err != nil { - return nil, errors.Errorf("preparing get unit endpoints statement: %w", err) - } - - insertUnitEndpoint, err := st.Prepare("INSERT INTO unit_endpoint (*) VALUES ($unitEndpoint.*)", unitEndpoint{}) - if err != nil { - return nil, errors.Errorf("preparing insert unit endpoint statement: %w", err) + return nil, errors.Errorf("preparing get endpoints statement: %w", err) } - var endpoints []endpoint - err = tx.Query(ctx, getUnitEndpoints, unitUUID, endpointsUnderAction).GetAll(&endpoints) + endpoints := []endpoint{} + err = tx.Query(ctx, getEndpoints, unitUUID, endpointsUnderAction).GetAll(&endpoints) if err != nil && !errors.Is(err, sqlair.ErrNoRows) { return nil, errors.Capture(err) } - foundEndpoints := set.NewStrings() - for _, ep := range endpoints { - foundEndpoints.Add(ep.Endpoint) - } - - // Insert any new endpoints that are required. - requiredEndpoints := set.NewStrings(endpointsUnderAction...).Difference(foundEndpoints) - newUnitEndpoints := make([]unitEndpoint, requiredEndpoints.Size()) - for i, requiredEndpoint := range requiredEndpoints.Values() { - uuid, err := uuid.NewUUID() - if err != nil { - return nil, errors.Errorf("generating UUID for unit endpoint: %w", err) - } - newUnitEndpoints[i] = unitEndpoint{ - UUID: uuid.String(), - UnitUUID: unitUUID.UUID, - Endpoint: requiredEndpoint, - } - endpoints = append(endpoints, endpoint{ - Endpoint: requiredEndpoint, - UUID: uuid.String(), - }) - } - - if len(newUnitEndpoints) > 0 { - err = tx.Query(ctx, insertUnitEndpoint, newUnitEndpoints).Run() - if err != nil { - return nil, errors.Capture(err) + if len(endpoints) != len(endpointsUnderAction) { + endpointsSet := set.NewStrings([]string(endpointsUnderAction)...) + for _, ep := range endpoints { + endpointsSet.Remove(ep.Endpoint) } + return nil, errors.Errorf( + "%w; %v does exist on unit %v", + porterrors.InvalidEndpoint, endpointsSet.Values(), unitUUID.UUID, + ) } return endpoints, nil @@ -395,16 +360,9 @@ AND endpoint IN ($endpoints[:]) // their UUIDs, which are not needed by GetUnitOpenedPorts. func (st *State) getUnitOpenedPorts(ctx context.Context, tx *sqlair.TX, unitUUID unitUUID) ([]endpointPortRangeUUID, error) { getOpenedPorts, err := st.Prepare(` -SELECT - port_range.uuid AS &endpointPortRangeUUID.uuid, - protocol.protocol AS &endpointPortRangeUUID.protocol, - port_range.from_port AS &endpointPortRangeUUID.from_port, - port_range.to_port AS &endpointPortRangeUUID.to_port, - unit_endpoint.endpoint AS &endpointPortRangeUUID.endpoint -FROM port_range -JOIN protocol ON port_range.protocol_id = protocol.id -JOIN unit_endpoint ON port_range.unit_endpoint_uuid = unit_endpoint.uuid -WHERE unit_endpoint.unit_uuid = $unitUUID.unit_uuid +SELECT &endpointPortRangeUUID.* +FROM v_port_range +WHERE unit_uuid = $unitUUID.unit_uuid `, endpointPortRangeUUID{}, unitUUID) if err != nil { return nil, errors.Errorf("preparing get opened ports statement: %w", err) @@ -424,7 +382,7 @@ WHERE unit_endpoint.unit_uuid = $unitUUID.unit_uuid // openPorts inserts the given port ranges into the database, unless they're already open. func (st *State) openPorts( ctx context.Context, tx *sqlair.TX, - openPorts network.GroupedPortRanges, currentOpenedPorts []endpointPortRangeUUID, endpoints []endpoint, + openPorts network.GroupedPortRanges, currentOpenedPorts []endpointPortRangeUUID, unitUUID unitUUID, endpoints []endpoint, ) error { insertPortRange, err := st.Prepare("INSERT INTO port_range (*) VALUES ($unitPortRange.*)", unitPortRange{}) if err != nil { @@ -436,6 +394,12 @@ func (st *State) openPorts( return errors.Errorf("getting protocol map: %w", err) } + // construct a map from endpoint name to it's UUID. + endpointUUIDMaps := make(map[string]string) + for _, ep := range endpoints { + endpointUUIDMaps[ep.Endpoint] = ep.UUID + } + // index the current opened ports by endpoint and port range currentOpenedPortRangeExistenceIndex := make(map[string]map[network.PortRange]bool) for _, openedPortRange := range currentOpenedPorts { @@ -445,41 +409,35 @@ func (st *State) openPorts( currentOpenedPortRangeExistenceIndex[openedPortRange.Endpoint][openedPortRange.decode()] = true } - // Construct the new port ranges to open - var openPortRanges []unitPortRange - - for _, ep := range endpoints { - ports, ok := openPorts[ep.Endpoint] - if !ok { - continue - } - + for ep, ports := range openPorts { for _, portRange := range ports { // skip port range if it's already open on this endpoint - if _, ok := currentOpenedPortRangeExistenceIndex[ep.Endpoint][portRange]; ok { + if _, ok := currentOpenedPortRangeExistenceIndex[ep][portRange]; ok { continue } - uuid, err := uuid.NewUUID() if err != nil { return errors.Errorf("generating UUID for port range: %w", err) } - openPortRanges = append(openPortRanges, unitPortRange{ - UUID: uuid.String(), - ProtocolID: protocolMap[portRange.Protocol], - FromPort: portRange.FromPort, - ToPort: portRange.ToPort, - UnitEndpointUUID: ep.UUID, - }) + var relationUUID string + if ep != port.WildcardEndpoint { + relationUUID = endpointUUIDMaps[ep] + } + unitPortRange := unitPortRange{ + UUID: uuid.String(), + ProtocolID: protocolMap[portRange.Protocol], + FromPort: portRange.FromPort, + ToPort: portRange.ToPort, + UnitUUID: unitUUID.UUID, + RelationUUID: relationUUID, + } + err = tx.Query(ctx, insertPortRange, unitPortRange).Run() + if err != nil { + return errors.Capture(err) + } } } - if len(openPortRanges) > 0 { - err = tx.Query(ctx, insertPortRange, openPortRanges).Run() - if err != nil { - return errors.Capture(err) - } - } return nil } @@ -557,11 +515,9 @@ func (st *State) GetEndpoints(ctx domain.AtomicContext, unit coreunit.UUID) ([]s unitUUID := unitUUID{UUID: unit} getEndpoints, err := st.Prepare(` -SELECT charm_relation.key AS &endpointName.endpoint -FROM unit -JOIN application ON unit.application_uuid = application.uuid -JOIN charm_relation ON application.charm_uuid = charm_relation.charm_uuid -WHERE unit.uuid = $unitUUID.unit_uuid +SELECT &endpointName.* +FROM v_endpoint +WHERE unit_uuid = $unitUUID.unit_uuid `, endpointName{}, unitUUID) if err != nil { return nil, errors.Errorf("preparing get endpoints statement: %w", err) diff --git a/domain/port/state/state_test.go b/domain/port/state/state_test.go index e65bfd011ce..bbfa1a9b93a 100644 --- a/domain/port/state/state_test.go +++ b/domain/port/state/state_test.go @@ -196,7 +196,7 @@ func (s *stateSuite) TestGetAllOpenedPorts(c *gc.C) { {Protocol: "tcp", FromPort: 443, ToPort: 443}, {Protocol: "udp", FromPort: 2000, ToPort: 2500}, }, - "endpoint-2": { + "ep1": { {Protocol: "udp", FromPort: 2000, ToPort: 2500}, }, }, network.GroupedPortRanges{}) @@ -573,6 +573,36 @@ func (s *stateSuite) TestUpdateUnitPortsOpenPort(c *gc.C) { c.Check(groupedPortRanges["ep1"][0], jc.DeepEquals, network.PortRange{Protocol: "tcp", FromPort: 8080, ToPort: 8080}) } +func (s *stateSuite) TestUpdateUnitPortsOpenPortWildcardEndpoint(c *gc.C) { + st := NewState(s.TxnRunnerFactory()) + ctx := context.Background() + + err := st.RunAtomic(ctx, func(ctx domain.AtomicContext) error { + return st.UpdateUnitPorts(ctx, s.unitUUID, network.GroupedPortRanges{ + port.WildcardEndpoint: {{Protocol: "tcp", FromPort: 1000, ToPort: 1500}}, + }, network.GroupedPortRanges{}) + }) + c.Assert(err, jc.ErrorIsNil) + + groupedPortRanges, err := st.GetUnitOpenedPorts(ctx, s.unitUUID) + c.Assert(err, jc.ErrorIsNil) + c.Check(groupedPortRanges, gc.HasLen, 1) + c.Check(groupedPortRanges[port.WildcardEndpoint], gc.HasLen, 1) + c.Check(groupedPortRanges[port.WildcardEndpoint][0], jc.DeepEquals, network.PortRange{Protocol: "tcp", FromPort: 1000, ToPort: 1500}) +} + +func (s *stateSuite) TestUpdateUnitPortsOpenOnInvalidEndpoint(c *gc.C) { + st := NewState(s.TxnRunnerFactory()) + ctx := context.Background() + + err := st.RunAtomic(ctx, func(ctx domain.AtomicContext) error { + return st.UpdateUnitPorts(ctx, s.unitUUID, network.GroupedPortRanges{ + "invalid": {{Protocol: "tcp", FromPort: 1000, ToPort: 1500}}, + }, network.GroupedPortRanges{}) + }) + c.Assert(err, jc.ErrorIs, porterrors.InvalidEndpoint) +} + func (s *stateSuite) TestUpdateUnitPortsClosePort(c *gc.C) { st := NewState(s.TxnRunnerFactory()) ctx := context.Background() diff --git a/domain/port/state/types.go b/domain/port/state/types.go index e2b6ba05f0c..ec22b327f17 100644 --- a/domain/port/state/types.go +++ b/domain/port/state/types.go @@ -117,11 +117,12 @@ type portRangeUUIDs []string // unitPortRange represents a range of ports for a given protocol by id for a // given unit's endpoint by uuid. type unitPortRange struct { - UUID string `db:"uuid"` - ProtocolID int `db:"protocol_id"` - FromPort int `db:"from_port"` - ToPort int `db:"to_port"` - UnitEndpointUUID string `db:"unit_endpoint_uuid"` + UUID string `db:"uuid"` + ProtocolID int `db:"protocol_id"` + FromPort int `db:"from_port"` + ToPort int `db:"to_port"` + RelationUUID string `db:"relation_uuid,omitempty"` + UnitUUID unit.UUID `db:"unit_uuid"` } // endpoint represents a network endpoint and its UUID. @@ -130,10 +131,6 @@ type endpoint struct { Endpoint string `db:"endpoint"` } -type endpointUUID struct { - UUID string `db:"uuid"` -} - // endpointName represents a network endpoint's name. type endpointName struct { Endpoint string `db:"endpoint"` @@ -142,12 +139,7 @@ type endpointName struct { // endpoints represents a list of network endpoints. type endpoints []string -// unitEndpoint represents a unit's endpoint and its UUID. -type unitEndpoint struct { - UUID string `db:"uuid"` - Endpoint string `db:"endpoint"` - UnitUUID unit.UUID `db:"unit_uuid"` -} +type unitUUIDs []unit.UUID // unitUUID represents a unit's UUID. type unitUUID struct { diff --git a/domain/port/state/watcher.go b/domain/port/state/watcher.go index 5faf1a9ca2a..367ff709189 100644 --- a/domain/port/state/watcher.go +++ b/domain/port/state/watcher.go @@ -13,6 +13,7 @@ import ( coreapplication "github.com/juju/juju/core/application" coremachine "github.com/juju/juju/core/machine" + "github.com/juju/juju/core/unit" "github.com/juju/juju/internal/errors" ) @@ -27,74 +28,70 @@ func (st *State) InitialWatchMachineOpenedPortsStatement() string { return "SELECT name FROM machine" } -// GetMachineNamesForUnitEndpoints returns a slice of machine names that host the provided endpoints. -func (st *State) GetMachineNamesForUnitEndpoints(ctx context.Context, eps []string) ([]coremachine.Name, error) { +// GetMachineNamesForUnits returns a slice of machine names that host the provided units. +func (st *State) GetMachineNamesForUnits(ctx context.Context, units []unit.UUID) ([]coremachine.Name, error) { db, err := st.DB() if err != nil { return nil, jujuerrors.Trace(err) } - endpointUUIDs := endpoints(eps) + unitUUIDs := unitUUIDs(units) query, err := st.Prepare(` SELECT DISTINCT machine.name AS &machineName.name FROM machine JOIN unit ON machine.net_node_uuid = unit.net_node_uuid -JOIN unit_endpoint ON unit.uuid = unit_endpoint.unit_uuid -WHERE unit_endpoint.uuid IN ($endpoints[:]) -`, machineName{}, endpointUUIDs) +WHERE unit.uuid IN ($unitUUIDs[:]) +`, machineName{}, unitUUIDs) if err != nil { - return nil, errors.Errorf("failed to prepare machine for endpoint query: %w", err) + return nil, errors.Errorf("failed to prepare machine for unit query: %w", err) } machineNames := []machineName{} err = db.Txn(ctx, func(ctx context.Context, tx *sqlair.TX) error { - err := tx.Query(ctx, query, endpointUUIDs).GetAll(&machineNames) + err := tx.Query(ctx, query, unitUUIDs).GetAll(&machineNames) if errors.Is(err, sqlair.ErrNoRows) { return nil } return jujuerrors.Trace(err) }) if err != nil { - return nil, errors.Errorf("failed to get machines for endpoints: %w", err) + return nil, errors.Errorf("failed to get machines for units: %w", err) } return transform.Slice(machineNames, func(m machineName) coremachine.Name { return m.Name }), nil } -// FilterEndpointsForApplication returns the subset of provided endpoint uuids -// that are associated with the provided application. -func (st *State) FilterEndpointsForApplication(ctx context.Context, eps []string, app coreapplication.ID) (set.Strings, error) { +func (st *State) FilterUnitUUIDsForApplication(ctx context.Context, units []unit.UUID, app coreapplication.ID) (set.Strings, error) { db, err := st.DB() if err != nil { return nil, jujuerrors.Trace(err) } applicationUUID := applicationUUID{UUID: app} - endpointUUIDs := endpoints(eps) + unitUUIDs := unitUUIDs(units) query, err := st.Prepare(` -SELECT unit_endpoint.uuid AS &endpointUUID.uuid +SELECT uuid AS &unitUUID.unit_uuid FROM unit -JOIN unit_endpoint ON unit.uuid = unit_endpoint.unit_uuid -WHERE unit_endpoint.uuid IN ($endpoints[:]) +WHERE unit.uuid IN ($unitUUIDs[:]) AND unit.application_uuid = $applicationUUID.application_uuid -`, endpointUUID{}, applicationUUID, endpointUUIDs) +`, unitUUID{}, applicationUUID, unitUUIDs) if err != nil { - return nil, errors.Errorf("failed to prepare application for endpoint query: %w", err) + return nil, errors.Errorf("failed to prepare application for unit query: %w", err) } - filteredEps := []endpointUUID{} + filteredUnits := []unitUUID{} err = db.Txn(ctx, func(ctx context.Context, tx *sqlair.TX) error { - err := tx.Query(ctx, query, applicationUUID, endpointUUIDs).GetAll(&filteredEps) + err := tx.Query(ctx, query, applicationUUID, unitUUIDs).GetAll(&filteredUnits) if errors.Is(err, sqlair.ErrNoRows) { return nil } return jujuerrors.Trace(err) }) if err != nil { - return nil, errors.Errorf("failed to get applications for endpoints: %w", err) + return nil, errors.Errorf("failed to get applications for units: %w", err) } - return set.NewStrings(transform.Slice(filteredEps, func(e endpointUUID) string { return e.UUID })...), nil + return set.NewStrings(transform.Slice(filteredUnits, func(u unitUUID) string { return u.UUID.String() })...), nil } diff --git a/domain/port/state/watcher_test.go b/domain/port/state/watcher_test.go index b783bd26058..6a9f947a1bd 100644 --- a/domain/port/state/watcher_test.go +++ b/domain/port/state/watcher_test.go @@ -6,16 +6,13 @@ package state import ( "context" - "github.com/canonical/sqlair" "github.com/juju/collections/set" jc "github.com/juju/testing/checkers" gc "gopkg.in/check.v1" coreapplication "github.com/juju/juju/core/application" "github.com/juju/juju/core/machine" - "github.com/juju/juju/core/network" coreunit "github.com/juju/juju/core/unit" - "github.com/juju/juju/domain" machinestate "github.com/juju/juju/domain/machine/state" "github.com/juju/juju/internal/logger" ) @@ -30,12 +27,6 @@ type watcherSuite struct { var _ = gc.Suite(&watcherSuite{}) -var ( - ssh = network.PortRange{Protocol: "tcp", FromPort: 22, ToPort: 22} - http = network.PortRange{Protocol: "tcp", FromPort: 80, ToPort: 80} - https = network.PortRange{Protocol: "tcp", FromPort: 443, ToPort: 443} -) - func (s *watcherSuite) SetUpTest(c *gc.C) { s.ModelSuite.SetUpTest(c) @@ -54,48 +45,6 @@ func (s *watcherSuite) SetUpTest(c *gc.C) { s.unitUUIDs[2], _ = s.createUnit(c, netNodeUUIDs[1], appNames[1]) } -func (s *watcherSuite) initialiseOpenPorts(c *gc.C, st *State) ([]string, map[string]string) { - err := st.RunAtomic(context.Background(), func(ctx domain.AtomicContext) error { - err := st.UpdateUnitPorts(ctx, s.unitUUIDs[0], network.GroupedPortRanges{ - "ep0": {ssh}, - }, network.GroupedPortRanges{}) - if err != nil { - return err - } - err = st.UpdateUnitPorts(ctx, s.unitUUIDs[1], network.GroupedPortRanges{ - "ep1": {http}, - }, network.GroupedPortRanges{}) - if err != nil { - return err - } - return st.UpdateUnitPorts(ctx, s.unitUUIDs[2], network.GroupedPortRanges{ - "ep2": {https}, - }, network.GroupedPortRanges{}) - }) - c.Assert(err, jc.ErrorIsNil) - - query, err := sqlair.Prepare(` -SELECT &endpoint.* -FROM unit_endpoint -`, endpoint{}) - c.Assert(err, jc.ErrorIsNil) - - var endpoints []endpoint - err = s.TxnRunner().Txn(context.Background(), func(ctx context.Context, tx *sqlair.TX) error { - return tx.Query(ctx, query).GetAll(&endpoints) - }) - c.Assert(err, jc.ErrorIsNil) - - endpointToUUIDMap := make(map[string]string) - endpointUUIDs := make([]string, len(endpoints)) - for i, ep := range endpoints { - endpointToUUIDMap[ep.Endpoint] = ep.UUID - endpointUUIDs[i] = ep.UUID - } - - return endpointUUIDs, endpointToUUIDMap -} - /* * The following tests will run with the following context: * - 3 units are deployed (with uuids stored in s.unitUUIDs) @@ -105,23 +54,17 @@ FROM unit_endpoint * - on 2 applications (with names stored in appNames; uuids s.appUUIDs) * - unit 0 is deployed to app 0 * - units 1 & 2 are deployed to app 1 -* -* - The following ports are open: -* - ssh is open on endpoint 0 on unit 0 -* - http is open on endpoint 1 on unit 1 -* - https is open on endpoint 2 on unit 2 */ func (s *watcherSuite) TestGetMachinesForUnitEndpoints(c *gc.C) { st := NewState(s.TxnRunnerFactory()) ctx := context.Background() - endpointUUIDs, endpointToUUIDMap := s.initialiseOpenPorts(c, st) - machineUUIDsForEndpoint, err := st.GetMachineNamesForUnitEndpoints(ctx, endpointUUIDs) + machineUUIDsForEndpoint, err := st.GetMachineNamesForUnits(ctx, s.unitUUIDs[:]) c.Assert(err, jc.ErrorIsNil) c.Check(machineUUIDsForEndpoint, jc.SameContents, []machine.Name{"0", "1"}) - machineUUIDsForEndpoint, err = st.GetMachineNamesForUnitEndpoints(ctx, []string{endpointToUUIDMap["ep0"]}) + machineUUIDsForEndpoint, err = st.GetMachineNamesForUnits(ctx, []coreunit.UUID{s.unitUUIDs[0]}) c.Assert(err, jc.ErrorIsNil) c.Check(machineUUIDsForEndpoint, jc.DeepEquals, []machine.Name{"0"}) } @@ -129,13 +72,12 @@ func (s *watcherSuite) TestGetMachinesForUnitEndpoints(c *gc.C) { func (s *watcherSuite) TestFilterEndpointForApplication(c *gc.C) { st := NewState(s.TxnRunnerFactory()) ctx := context.Background() - endpointUUIDs, endpointToUUIDMap := s.initialiseOpenPorts(c, st) - filteredEndpointUUIDs, err := st.FilterEndpointsForApplication(ctx, endpointUUIDs, s.appUUIDs[0]) + filteredUnits, err := st.FilterUnitUUIDsForApplication(ctx, s.unitUUIDs[:], s.appUUIDs[0]) c.Assert(err, jc.ErrorIsNil) - c.Check(filteredEndpointUUIDs, jc.DeepEquals, set.NewStrings(endpointToUUIDMap["ep0"])) + c.Check(filteredUnits, jc.DeepEquals, set.NewStrings(s.unitUUIDs[0].String())) - filteredEndpointUUIDs, err = st.FilterEndpointsForApplication(ctx, endpointUUIDs, s.appUUIDs[1]) + filteredUnits, err = st.FilterUnitUUIDsForApplication(ctx, s.unitUUIDs[:], s.appUUIDs[1]) c.Assert(err, jc.ErrorIsNil) - c.Check(filteredEndpointUUIDs, jc.DeepEquals, set.NewStrings(endpointToUUIDMap["ep1"], endpointToUUIDMap["ep2"])) + c.Check(filteredUnits, jc.DeepEquals, set.NewStrings(s.unitUUIDs[1].String(), s.unitUUIDs[2].String())) } diff --git a/domain/port/watcher_test.go b/domain/port/watcher_test.go index ca7cbf2e295..f6defd7c577 100644 --- a/domain/port/watcher_test.go +++ b/domain/port/watcher_test.go @@ -19,13 +19,11 @@ import ( "github.com/juju/juju/domain" "github.com/juju/juju/domain/application" "github.com/juju/juju/domain/application/charm" - applicationerrors "github.com/juju/juju/domain/application/errors" applicationstate "github.com/juju/juju/domain/application/state" machinestate "github.com/juju/juju/domain/machine/state" "github.com/juju/juju/domain/port/service" "github.com/juju/juju/domain/port/state" changestreamtesting "github.com/juju/juju/internal/changestream/testing" - "github.com/juju/juju/internal/errors" "github.com/juju/juju/internal/logger" loggertesting "github.com/juju/juju/internal/logger/testing" ) @@ -76,27 +74,45 @@ func (s *watcherSuite) SetUpTest(c *gc.C) { c.Assert(err, jc.ErrorIsNil) } +func (s *watcherSuite) createApplicationWithRelations(c *gc.C, appName string, relations ...string) coreapplication.ID { + relationsMap := map[string]charm.Relation{} + for _, relation := range relations { + relationsMap[relation] = charm.Relation{ + Name: relation, + Key: relation, + Role: charm.RoleRequirer, + Scope: charm.ScopeGlobal, + } + } + + applicationSt := applicationstate.NewApplicationState(s.TxnRunnerFactory(), logger.GetLogger("juju.test.application")) + var appUUID coreapplication.ID + err := applicationSt.RunAtomic(context.Background(), func(ctx domain.AtomicContext) error { + var err error + appUUID, err = applicationSt.CreateApplication(ctx, appName, application.AddApplicationArg{ + Charm: charm.Charm{ + Metadata: charm.Metadata{ + Name: appName, + Requires: relationsMap, + }, + }, + }) + return err + }) + c.Assert(err, jc.ErrorIsNil) + return appUUID +} + // createUnit creates a new unit in state and returns its UUID. The unit is assigned // to the net node with uuid `netNodeUUID`. -func (s *watcherSuite) createUnit(c *gc.C, netNodeUUID, appName string) (coreunit.UUID, coreapplication.ID) { +func (s *watcherSuite) createUnit(c *gc.C, netNodeUUID, appName string) coreunit.UUID { applicationSt := applicationstate.NewApplicationState(s.TxnRunnerFactory(), logger.GetLogger("juju.test.application")) unitName, err := coreunit.NewNameFromParts(appName, s.unitCount) c.Assert(err, jc.ErrorIsNil) err = applicationSt.RunAtomic(context.Background(), func(ctx domain.AtomicContext) error { appID, err := applicationSt.GetApplicationID(ctx, appName) - if err != nil && !errors.Is(err, applicationerrors.ApplicationNotFound) { - return err - } if err != nil { - if appID, err = applicationSt.CreateApplication(ctx, appName, application.AddApplicationArg{ - Charm: charm.Charm{ - Metadata: charm.Metadata{ - Name: appName, - }, - }, - }); err != nil { - return err - } + return err } err = applicationSt.AddUnits(ctx, appID, application.AddUnitArg{UnitName: unitName}) if err != nil { @@ -107,21 +123,13 @@ func (s *watcherSuite) createUnit(c *gc.C, netNodeUUID, appName string) (coreuni }) c.Assert(err, jc.ErrorIsNil) - var ( - unitUUID coreunit.UUID - appUUID coreapplication.ID - ) + var unitUUID coreunit.UUID err = s.TxnRunner().StdTxn(context.Background(), func(ctx context.Context, tx *sql.Tx) error { err := tx.QueryRowContext(ctx, "SELECT uuid FROM unit WHERE name = ?", unitName).Scan(&unitUUID) if err != nil { return err } - err = tx.QueryRowContext(ctx, "SELECT uuid FROM application WHERE name = ?", appName).Scan(&appUUID) - if err != nil { - return err - } - _, err = tx.ExecContext(ctx, "INSERT INTO net_node VALUES (?) ON CONFLICT DO NOTHING", netNodeUUID) if err != nil { return err @@ -135,7 +143,7 @@ func (s *watcherSuite) createUnit(c *gc.C, netNodeUUID, appName string) (coreuni return nil }) c.Assert(err, jc.ErrorIsNil) - return unitUUID, appUUID + return unitUUID } /* @@ -150,9 +158,12 @@ func (s *watcherSuite) createUnit(c *gc.C, netNodeUUID, appName string) (coreuni */ func (s *watcherSuite) TestWatchMachinePortRanges(c *gc.C) { - s.unitUUIDs[0], s.appUUIDs[0] = s.createUnit(c, netNodeUUIDs[0], appNames[0]) - s.unitUUIDs[1], s.appUUIDs[1] = s.createUnit(c, netNodeUUIDs[0], appNames[1]) - s.unitUUIDs[2], _ = s.createUnit(c, netNodeUUIDs[1], appNames[1]) + s.appUUIDs[0] = s.createApplicationWithRelations(c, appNames[0], "ep0", "ep1", "ep2", "ep3") + s.appUUIDs[1] = s.createApplicationWithRelations(c, appNames[1], "ep0", "ep1", "ep2", "ep3") + + s.unitUUIDs[0] = s.createUnit(c, netNodeUUIDs[0], appNames[0]) + s.unitUUIDs[1] = s.createUnit(c, netNodeUUIDs[0], appNames[1]) + s.unitUUIDs[2] = s.createUnit(c, netNodeUUIDs[1], appNames[1]) watcher, err := s.srv.WatchMachineOpenedPorts(context.Background()) c.Assert(err, jc.ErrorIsNil) @@ -257,9 +268,12 @@ func (s *watcherSuite) TestWatchMachinePortRanges(c *gc.C) { } func (s *watcherSuite) TestWatchOpenedPortsForApplication(c *gc.C) { - s.unitUUIDs[0], s.appUUIDs[0] = s.createUnit(c, netNodeUUIDs[0], appNames[0]) - s.unitUUIDs[1], s.appUUIDs[1] = s.createUnit(c, netNodeUUIDs[0], appNames[1]) - s.unitUUIDs[2], _ = s.createUnit(c, netNodeUUIDs[1], appNames[1]) + s.appUUIDs[0] = s.createApplicationWithRelations(c, appNames[0], "ep0", "ep1", "ep2") + s.appUUIDs[1] = s.createApplicationWithRelations(c, appNames[1], "ep0", "ep1", "ep2") + + s.unitUUIDs[0] = s.createUnit(c, netNodeUUIDs[0], appNames[0]) + s.unitUUIDs[1] = s.createUnit(c, netNodeUUIDs[0], appNames[1]) + s.unitUUIDs[2] = s.createUnit(c, netNodeUUIDs[1], appNames[1]) watcher, err := s.srv.WatchOpenedPortsForApplication(context.Background(), s.appUUIDs[1]) c.Assert(err, jc.ErrorIsNil) diff --git a/domain/schema/model.go b/domain/schema/model.go index ea867f2fcfd..9258438f222 100644 --- a/domain/schema/model.go +++ b/domain/schema/model.go @@ -23,6 +23,7 @@ import ( //go:generate go run ./../../generate/triggergen -db=model -destination=./model/triggers/charm.gen.go -package=triggers -tables=charm //go:generate go run ./../../generate/triggergen -db=model -destination=./model/triggers/unit.gen.go -package=triggers -tables=unit //go:generate go run ./../../generate/triggergen -db=model -destination=./model/triggers/application-scale.gen.go -package=triggers -tables=application_scale +//go:generate go run ./../../generate/triggergen -db=model -destination=./model/triggers/port-range.gen.go -package=triggers -tables=port_range //go:embed model/sql/*.sql var modelSchemaDir embed.FS @@ -111,7 +112,7 @@ func ModelDDL() *schema.Schema { triggers.ChangeLogTriggersForCharm("uuid", tableCharm), triggers.ChangeLogTriggersForUnit("uuid", tableUnit), triggers.ChangeLogTriggersForApplicationScale("application_uuid", tableApplicationScale), - triggers.ChangeLogTriggersForPortRange("unit_endpoint_uuid", tablePortRange), + triggers.ChangeLogTriggersForPortRange("unit_uuid", tablePortRange), triggers.ChangeLogTriggersForSecretDeletedValueRef("revision_uuid", tableSecretDeletedValueRef), ) diff --git a/domain/schema/model/sql/0021-port-ranges.sql b/domain/schema/model/sql/0021-port-ranges.sql index 3be56041b34..d5396ff8055 100644 --- a/domain/schema/model/sql/0021-port-ranges.sql +++ b/domain/schema/model/sql/0021-port-ranges.sql @@ -10,49 +10,49 @@ INSERT INTO protocol VALUES CREATE TABLE port_range ( uuid TEXT NOT NULL PRIMARY KEY, - unit_endpoint_uuid TEXT NOT NULL, protocol_id INT NOT NULL, from_port INT, to_port INT, + relation_uuid TEXT, -- NULL-able, where null represents a wildcard endpoint + unit_uuid TEXT NOT NULL, CONSTRAINT fk_port_range_protocol FOREIGN KEY (protocol_id) REFERENCES protocol (id), - CONSTRAINT fk_port_range_unit_endpoint - FOREIGN KEY (unit_endpoint_uuid) - REFERENCES unit_endpoint (uuid) + CONSTRAINT fk_port_range_relation + FOREIGN KEY (relation_uuid) + REFERENCES charm_relation (uuid), + CONSTRAINT fk_port_range_unit + FOREIGN KEY (unit_uuid) + REFERENCES unit (uuid) ); -- We disallow overlapping port ranges, however this cannot reasonably -- be enforced in the schema. Including the from_port in the uniqueness -- constraint is as far as we go here. Non-overlapping ranges must be -- enforced in the service/state layer. -CREATE UNIQUE INDEX idx_port_range_endpoint_port_range ON port_range (unit_endpoint_uuid, protocol_id, from_port); +CREATE UNIQUE INDEX idx_port_range_endpoint ON port_range (protocol_id, from_port, relation_uuid, unit_uuid); -CREATE TABLE unit_endpoint ( - uuid TEXT NOT NULL PRIMARY KEY, - endpoint TEXT NOT NULL, - unit_uuid TEXT NOT NULL, - CONSTRAINT fk_endpoint_unit - FOREIGN KEY (unit_uuid) - REFERENCES unit (uuid) -); - -CREATE UNIQUE INDEX idx_unit_endpoint_endpoint_unit_uuid ON unit_endpoint (endpoint, unit_uuid); - --- v_port_range is used to fetch well constructed information about a port range. --- This will include information about the port range's endpoint and unit, as well --- as the uuids for the net node and application each endpoint is located on. CREATE VIEW v_port_range AS SELECT + pr.uuid, pr.from_port, pr.to_port, - protocol.protocol, - ue.endpoint, + pr.unit_uuid, u.name AS unit_name, - u.net_node_uuid, - u.application_uuid + protocol.protocol, + cr."key" AS endpoint FROM port_range AS pr LEFT JOIN protocol ON pr.protocol_id = protocol.id -LEFT JOIN unit_endpoint AS ue ON pr.unit_endpoint_uuid = ue.uuid -LEFT JOIN unit AS u ON ue.unit_uuid = u.uuid; +LEFT JOIN charm_relation AS cr ON pr.relation_uuid = cr.uuid +LEFT JOIN unit AS u ON pr.unit_uuid = u.uuid; + +CREATE VIEW v_endpoint +AS +SELECT + cr.uuid, + cr."key" AS endpoint, + u.uuid AS unit_uuid +FROM unit AS u +LEFT JOIN application AS a ON u.application_uuid = a.uuid +LEFT JOIN charm_relation AS cr ON a.charm_uuid = cr.charm_uuid; diff --git a/domain/schema/model/triggers/port-range.gen.go b/domain/schema/model/triggers/port-range.gen.go index 53df8d8e5e8..f6db9cfb3d7 100644 --- a/domain/schema/model/triggers/port-range.gen.go +++ b/domain/schema/model/triggers/port-range.gen.go @@ -29,11 +29,12 @@ END; CREATE TRIGGER trg_log_port_range_update AFTER UPDATE ON port_range FOR EACH ROW WHEN - (NEW.uuid != OLD.uuid OR (NEW.uuid IS NOT NULL AND OLD.uuid IS NULL) OR (NEW.uuid IS NULL AND OLD.uuid IS NOT NULL)) OR - NEW.unit_endpoint_uuid != OLD.unit_endpoint_uuid OR + NEW.uuid != OLD.uuid OR NEW.protocol_id != OLD.protocol_id OR (NEW.from_port != OLD.from_port OR (NEW.from_port IS NOT NULL AND OLD.from_port IS NULL) OR (NEW.from_port IS NULL AND OLD.from_port IS NOT NULL)) OR - (NEW.to_port != OLD.to_port OR (NEW.to_port IS NOT NULL AND OLD.to_port IS NULL) OR (NEW.to_port IS NULL AND OLD.to_port IS NOT NULL)) + (NEW.to_port != OLD.to_port OR (NEW.to_port IS NOT NULL AND OLD.to_port IS NULL) OR (NEW.to_port IS NULL AND OLD.to_port IS NOT NULL)) OR + (NEW.relation_uuid != OLD.relation_uuid OR (NEW.relation_uuid IS NOT NULL AND OLD.relation_uuid IS NULL) OR (NEW.relation_uuid IS NULL AND OLD.relation_uuid IS NOT NULL)) OR + NEW.unit_uuid != OLD.unit_uuid BEGIN INSERT INTO change_log (edit_type_id, namespace_id, changed, created_at) VALUES (2, %[2]d, OLD.%[1]s, DATETIME('now')); diff --git a/domain/schema/schema_test.go b/domain/schema/schema_test.go index 7ff6ac47fa8..579fa6a6325 100644 --- a/domain/schema/schema_test.go +++ b/domain/schema/schema_test.go @@ -483,7 +483,6 @@ func (s *schemaSuite) TestModelTables(c *gc.C) { // Opened Ports "protocol", "port_range", - "unit_endpoint", ) got := readEntityNames(c, s.DB(), "table") wanted := expected.Union(internalTableNames) @@ -515,6 +514,7 @@ func (s *schemaSuite) TestModelViews(c *gc.C) { "v_charm_storage", "v_hardware_characteristics", "v_port_range", + "v_endpoint", "v_secret_permission", "v_space_subnet", diff --git a/testcharms/charm-repo/quantal/mysql/metadata.yaml b/testcharms/charm-repo/quantal/mysql/metadata.yaml index 3659c281111..79093e47f66 100644 --- a/testcharms/charm-repo/quantal/mysql/metadata.yaml +++ b/testcharms/charm-repo/quantal/mysql/metadata.yaml @@ -2,6 +2,8 @@ name: mysql summary: "Database engine" description: "A pretty popular database" provides: + db: + interface: db server: interface: mysql server-admin: