Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/openapi/gateway.swagger.json
Original file line number Diff line number Diff line change
Expand Up @@ -1480,6 +1480,9 @@
},
"tcp": {
"type": "boolean"
},
"blockNetwork": {
"type": "boolean"
}
}
},
Expand Down
1 change: 1 addition & 0 deletions pkg/abstractions/pod/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ func (i *podInstance) startContainers(containersToRun int) error {
Stub: *i.Stub,
CheckpointEnabled: checkpointEnabled,
Ports: ports,
BlockNetwork: i.StubConfig.BlockNetwork,
}

ttl := time.Duration(i.StubConfig.KeepWarmSeconds) * time.Second
Expand Down
1 change: 1 addition & 0 deletions pkg/abstractions/pod/pod.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ func (s *GenericPodService) run(ctx context.Context, authInfo *auth.AuthInfo, st
Ports: ports,
CheckpointEnabled: checkpointEnabled,
Checkpoint: checkpoint,
BlockNetwork: stubConfig.BlockNetwork,
})
if err != nil {
return "", err
Expand Down
1 change: 1 addition & 0 deletions pkg/gateway/gateway.proto
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,7 @@ message GetOrCreateStubRequest {
Schema inputs = 35;
Schema outputs = 36;
bool tcp = 37;
bool block_network = 38;
}

message GetOrCreateStubResponse {
Expand Down
1 change: 1 addition & 0 deletions pkg/gateway/services/stub.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ func (gws *GatewayService) GetOrCreateStub(ctx context.Context, in *pb.GetOrCrea
Inputs: inputs,
Outputs: outputs,
TCP: in.Tcp,
BlockNetwork: in.BlockNetwork,
}

// Ensure GPU count is at least 1 if a GPU is required
Expand Down
1 change: 1 addition & 0 deletions pkg/types/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,7 @@ type StubConfigV1 struct {
Inputs *Schema `json:"inputs"`
Outputs *Schema `json:"outputs"`
TCP bool `json:"tcp"`
BlockNetwork bool `json:"block_network"`
}

type StubConfigLimitedValues struct {
Expand Down
3 changes: 3 additions & 0 deletions pkg/types/scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ type ContainerRequest struct {
AppId string `json:"app_id"`
Checkpoint *Checkpoint `json:"checkpoint"`
ConfigPath string `json:"config_path"`
BlockNetwork bool `json:"block_network"`
}

func (c *ContainerRequest) RequiresGPU() bool {
Expand Down Expand Up @@ -297,6 +298,7 @@ func (c *ContainerRequest) ToProto() *pb.ContainerRequest {
Ports: c.Ports,
CheckpointEnabled: c.CheckpointEnabled,
Checkpoint: checkpoint,
BlockNetwork: c.BlockNetwork,
}
}

Expand Down Expand Up @@ -345,6 +347,7 @@ func NewContainerRequestFromProto(in *pb.ContainerRequest) *ContainerRequest {
BuildOptions: bo,
Ports: in.Ports,
Checkpoint: checkpoint,
BlockNetwork: in.BlockNetwork,
}
}

Expand Down
1 change: 1 addition & 0 deletions pkg/types/types.proto
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ message ContainerRequest {
string app_id = 23;
Checkpoint checkpoint = 24;
string config_path = 25;
bool block_network = 26;
}

message ContainerState {
Expand Down
61 changes: 59 additions & 2 deletions pkg/worker/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,13 +225,30 @@ func (m *ContainerNetworkManager) Setup(containerId string, spec *specs.Spec, re
if err != nil {
return err
}
defer netns.Set(hostNS) // Reset to the original namespace after setting up the container network

return m.configureContainerNetwork(&containerNetworkConfigOpts{
err = m.configureContainerNetwork(&containerNetworkConfigOpts{
containerId: containerId,
containerVeth: containerVeth,
request: request,
})

// Switch back to host namespace before setting up BlockNetwork rules
if nsErr := netns.Set(hostNS); nsErr != nil {
return fmt.Errorf("failed to switch back to host namespace: %w", nsErr)
}

if err != nil {
return err
}

// Block network in host namespace (must be done in host namespace to affect forwarding)
if request.BlockNetwork {
if err := m.setupBlockNetwork(containerId, request); err != nil {
return err
}
}

return nil
}

func (m *ContainerNetworkManager) createVethPair(hostVethName, containerVethName string) error {
Expand Down Expand Up @@ -512,6 +529,46 @@ func (m *ContainerNetworkManager) configureContainerNetwork(opts *containerNetwo
return nil
}

func (m *ContainerNetworkManager) setupBlockNetwork(containerId string, request *types.ContainerRequest) error {
// Get container IP
containerIpResponse, err := handleGRPCResponse(m.workerRepoClient.GetContainerIp(m.ctx, &pb.GetContainerIpRequest{
NetworkPrefix: m.networkPrefix,
ContainerId: containerId,
}))
if err != nil {
return err
}

containerIp := containerIpResponse.IpAddress

// Block IPv4 outbound traffic (but allow reply packets for exposed ports)
err = m.ipt.InsertUnique("filter", "FORWARD", 1, "-s", containerIp, "-o", m.defaultLink.Attrs().Name, "-m", "conntrack", "!", "--ctstate", "ESTABLISHED,RELATED", "-j", "DROP")
if err != nil {
return err
}

// Block IPv6 outbound traffic if enabled (but allow reply packets for exposed ports)
if m.ipt6 != nil {
// Calculate the corresponding IPv6 address using the last octet of the IPv4 address
ip := net.ParseIP(containerIp)
if ip == nil {
return fmt.Errorf("invalid IPv4 address: %s", containerIp)
}
ipv4LastOctet := int(ip.To4()[3])
_, ipv6Net, _ := net.ParseCIDR(containerSubnetIPv6)
ipv6Prefix := ipv6Net.IP.String()
ipv6Address := fmt.Sprintf("%s%x", ipv6Prefix, ipv4LastOctet)

err = m.ipt6.InsertUnique("filter", "FORWARD", 1, "-s", ipv6Address, "-o", m.defaultLink.Attrs().Name, "-m", "conntrack", "!", "--ctstate", "ESTABLISHED,RELATED", "-j", "DROP")
if err != nil {
return err
}
}

log.Info().Str("container_id", containerId).Str("ip_address", containerIp).Msg("outbound network access blocked for container")
return nil
}

func (m *ContainerNetworkManager) cleanupOrphanedNamespaces() {
ticker := time.NewTicker(containerNetworkCleanupInterval)
defer ticker.Stop()
Expand Down
Loading
Loading