Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/_nebari/stages/infrastructure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ class AWSInputVars(schema.Base):
kubernetes_version: str
eks_endpoint_access: Optional[
Literal["private", "public", "public_and_private"]
] = "public"
] = "public_and_private"
eks_kms_arn: Optional[str] = None
eks_public_access_cidrs: Optional[List[str]] = ["0.0.0.0/0"]
node_groups: List[AWSNodeGroupInputVars]
Expand Down Expand Up @@ -457,7 +457,7 @@ class AmazonWebServicesProvider(schema.Base):
node_groups: Dict[str, AWSNodeGroup] = DEFAULT_AWS_NODE_GROUPS
eks_endpoint_access: Optional[
Literal["private", "public", "public_and_private"]
] = "public"
] = "public_and_private"
eks_public_access_cidrs: Optional[List[str]] = ["0.0.0.0/0"]
eks_kms_arn: Optional[str] = None
existing_subnet_ids: Optional[List[str]] = None
Expand Down
13 changes: 7 additions & 6 deletions src/_nebari/stages/infrastructure/template/aws/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ data "aws_partition" "current" {}

locals {
# Only override_network if both existing_subnet_ids and existing_security_group_id are not null.
override_network = (var.existing_subnet_ids != null) && (var.existing_security_group_id != null)
subnet_ids = local.override_network ? var.existing_subnet_ids : module.network[0].subnet_ids
security_group_id = local.override_network ? var.existing_security_group_id : module.network[0].security_group_id
partition = data.aws_partition.current.partition
override_network = (var.existing_subnet_ids != null) && (var.existing_security_group_id != null)
private_subnet_ids = local.override_network ? var.existing_subnet_ids : module.network[0].private_subnet_ids
security_group_id = local.override_network ? var.existing_security_group_id : module.network[0].security_group_id
partition = data.aws_partition.current.partition
}

# ==================== ACCOUNTING ======================
Expand Down Expand Up @@ -50,6 +50,7 @@ module "network" {

vpc_cidr_block = var.vpc_cidr_block
aws_availability_zones = length(var.availability_zones) >= 2 ? var.availability_zones : slice(sort(data.aws_availability_zones.awszones.names), 0, 2)
region = var.region
}


Expand All @@ -70,7 +71,7 @@ module "efs" {
name = "${local.cluster_name}-jupyterhub-shared"
tags = local.additional_tags

efs_subnets = local.subnet_ids
efs_subnets = local.private_subnet_ids
efs_security_groups = [local.security_group_id]
}

Expand All @@ -88,7 +89,7 @@ module "kubernetes" {
region = var.region
kubernetes_version = var.kubernetes_version

cluster_subnets = local.subnet_ids
cluster_subnets = local.private_subnet_ids
cluster_security_groups = [local.security_group_id]

node_group_additional_policies = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,36 @@ resource "aws_vpc" "main" {
tags = merge({ Name = var.name }, var.tags, var.vpc_tags)
}

resource "aws_subnet" "main" {
resource "aws_subnet" "public" {
count = length(var.aws_availability_zones)

availability_zone = var.aws_availability_zones[count.index]
cidr_block = cidrsubnet(var.vpc_cidr_block, var.vpc_cidr_newbits, count.index)
vpc_id = aws_vpc.main.id
map_public_ip_on_launch = true
availability_zone = var.aws_availability_zones[count.index]
cidr_block = cidrsubnet(var.vpc_cidr_block, var.vpc_cidr_newbits, count.index)
vpc_id = aws_vpc.main.id

tags = merge({ Name = "${var.name}-subnet-${count.index}" }, var.tags, var.subnet_tags)
tags = merge({ Name = "${var.name}-pulbic-subnet-${count.index}", "kubernetes.io/role/elb" = 1 }, var.tags, var.subnet_tags)

lifecycle {
ignore_changes = [
availability_zone
]
}
}

moved {
from = aws_subnet.main
to = aws_subnet.public
}


resource "aws_subnet" "private" {
count = length(var.aws_availability_zones)

availability_zone = var.aws_availability_zones[count.index]
cidr_block = cidrsubnet(var.vpc_cidr_block, var.vpc_cidr_newbits, count.index + length(var.aws_availability_zones))
vpc_id = aws_vpc.main.id

tags = merge({ Name = "${var.name}-private-subnet-${count.index}" }, var.tags, var.subnet_tags)

lifecycle {
ignore_changes = [
Expand All @@ -30,7 +51,25 @@ resource "aws_internet_gateway" "main" {
tags = merge({ Name = var.name }, var.tags)
}

resource "aws_route_table" "main" {
resource "aws_eip" "nat-gateway-eip" {
count = length(var.aws_availability_zones)

domain = "vpc"

tags = merge({ Name = "${var.name}-nat-gateway-eip-${count.index}" }, var.tags)
}

resource "aws_nat_gateway" "main" {
count = length(var.aws_availability_zones)

allocation_id = aws_eip.nat-gateway-eip[count.index].id
subnet_id = aws_subnet.public[count.index].id

tags = merge({ Name = "${var.name}-nat-gateway-${count.index}" }, var.tags)
depends_on = [aws_internet_gateway.main]
}

resource "aws_route_table" "public" {
vpc_id = aws_vpc.main.id

route {
Expand All @@ -41,11 +80,36 @@ resource "aws_route_table" "main" {
tags = merge({ Name = var.name }, var.tags)
}

resource "aws_route_table_association" "main" {
moved {
from = aws_route_table.main
to = aws_route_table.public
}

resource "aws_route_table" "private" {
count = length(var.aws_availability_zones)

subnet_id = aws_subnet.main[count.index].id
route_table_id = aws_route_table.main.id
vpc_id = aws_vpc.main.id

route {
cidr_block = "0.0.0.0/0"
gateway_id = aws_nat_gateway.main[count.index].id
}

tags = merge({ Name = var.name }, var.tags)
}

resource "aws_route_table_association" "public" {
count = length(var.aws_availability_zones)

subnet_id = aws_subnet.public[count.index].id
route_table_id = aws_route_table.public.id
}

resource "aws_route_table_association" "private" {
count = length(var.aws_availability_zones)

subnet_id = aws_subnet.private[count.index].id
route_table_id = aws_route_table.private[count.index].id
}

resource "aws_security_group" "main" {
Expand All @@ -62,7 +126,6 @@ resource "aws_security_group" "main" {
cidr_blocks = [var.vpc_cidr_block]
}

#trivy:ignore:AVD-AWS-0104
egress {
description = "Allow all ports and protocols to exit the security group"
from_port = 0
Expand All @@ -73,3 +136,61 @@ resource "aws_security_group" "main" {

tags = merge({ Name = var.name }, var.tags, var.security_group_tags)
}

resource "aws_vpc_endpoint" "s3" {
vpc_id = aws_vpc.main.id
service_name = "com.amazonaws.${var.region}.s3"
vpc_endpoint_type = "Gateway"
route_table_ids = aws_route_table.private[*].id
tags = merge({ Name = "${var.name}-s3-endpoint" }, var.tags)
}

resource "aws_vpc_endpoint" "ecr_api" {
vpc_id = aws_vpc.main.id
service_name = "com.amazonaws.${var.region}.ecr.api"
vpc_endpoint_type = "Interface"
private_dns_enabled = true
security_group_ids = [aws_security_group.main.id]
subnet_ids = aws_subnet.private[*].id
tags = merge({ Name = "${var.name}-ecr-api-endpoint" }, var.tags)
}

resource "aws_vpc_endpoint" "ecr_dkr" {
vpc_id = aws_vpc.main.id
service_name = "com.amazonaws.${var.region}.ecr.dkr"
vpc_endpoint_type = "Interface"
private_dns_enabled = true
security_group_ids = [aws_security_group.main.id]
subnet_ids = aws_subnet.private[*].id
tags = merge({ Name = "${var.name}-ecr-dkr-endpoint" }, var.tags)
}

resource "aws_vpc_endpoint" "elasticloadbalancing" {
vpc_id = aws_vpc.main.id
service_name = "com.amazonaws.${var.region}.elasticloadbalancing"
vpc_endpoint_type = "Interface"
private_dns_enabled = true
security_group_ids = [aws_security_group.main.id]
subnet_ids = aws_subnet.private[*].id
tags = merge({ Name = "${var.name}-elb-endpoint" }, var.tags)
}

resource "aws_vpc_endpoint" "sts" {
vpc_id = aws_vpc.main.id
service_name = "com.amazonaws.${var.region}.sts"
vpc_endpoint_type = "Interface"
private_dns_enabled = true
security_group_ids = [aws_security_group.main.id]
subnet_ids = aws_subnet.private[*].id
tags = merge({ Name = "${var.name}-sts-endpoint" }, var.tags)
}

resource "aws_vpc_endpoint" "eks" {
vpc_id = aws_vpc.main.id
service_name = "com.amazonaws.${var.region}.eks"
vpc_endpoint_type = "Interface"
private_dns_enabled = true
security_group_ids = [aws_security_group.main.id]
subnet_ids = aws_subnet.private[*].id
tags = merge({ Name = "${var.name}-eks-endpoint" }, var.tags)
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,14 @@ output "security_group_id" {
value = aws_security_group.main.id
}

output "subnet_ids" {
description = "AWS VPC subnet ids"
value = aws_subnet.main[*].id
output "public_subnet_ids" {
description = "AWS VPC public subnet ids"
value = aws_subnet.public[*].id
}

output "private_subnet_ids" {
description = "AWS VPC private subnet ids"
value = aws_subnet.private[*].id
}

output "vpc_id" {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,11 @@ variable "vpc_cidr_block" {
variable "vpc_cidr_newbits" {
description = "VPC cidr number of bits to support 2^N subnets"
type = number
default = 2
default = 2 # allows 4 /18 subnets with 16382 addresses each
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
default = 2 # allows 4 /18 subnets with 16382 addresses each
default = 3 # allows 8 /18 subnets with 16382 addresses each

needed this for my use case with 3 subnets specified

}

variable "region" {
description = "AWS region to operate infrastructure"
type = string

}
3 changes: 2 additions & 1 deletion src/_nebari/stages/terraform_state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,8 @@ def check_immutable_fields(self):
# Return a default (mutable) extra field schema if bottom level is not a Pydantic model (such as a free-form 'overrides' block)
if isinstance(bottom_level_schema, BaseModel):
extra_field_schema = schema.ExtraFieldSchema(
**bottom_level_schema.model_fields[keys[-1]].json_schema_extra or {}
**type(bottom_level_schema).model_fields[keys[-1]].json_schema_extra
or {}
)
else:
extra_field_schema = schema.ExtraFieldSchema()
Expand Down
Loading