import argparse
import logging
import time
import boto3
from botocore.exceptions import ClientError
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class PytestContainerManager:
    """
    Responsible for spinning up and terminating containers to be used with pytest-xdist
    """
    def __init__(self, region, cluster):
        self.ecs = boto3.client('ecs', region)
        self.cluster_name = cluster

    def spin_up_containers(self, number_of_containers, task_name, subnets, security_groups, public_ip, launch_type):
        """
        Spins up containers and generates two .txt files, one containing the IP
        addresses of the new containers, the other containing their task_arns.
        """
        CONTAINER_RUN_TIME_OUT_MINUTES = 10
        MAX_RUN_TASK_RETRIES = 7
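        # Note: with the exponential backoff below (2 ** retry seconds), a fully
        # throttled run_task call sleeps 2 + 4 + ... + 64 = 126 seconds in total
        # before the final retry gives up and raises.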
        revision = self.ecs.describe_task_definition(taskDefinition=task_name)['taskDefinition']['revision']
        task_definition = "{}:{}".format(task_name, revision)
        logger.info("Spinning up {} containers based on task definition: {}".format(number_of_containers, task_definition))

        remainder = number_of_containers % 10
        quotient = number_of_containers // 10
        container_num_list = [10 for i in range(0, quotient)]
        if remainder:
            container_num_list.append(remainder)
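        # For example, number_of_containers = 25 yields container_num_list = [10, 10, 5],
        # i.e. three separate run_task calls below.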
        # Boot up containers. boto3's run_task only allows 10 containers to be launched at a time
        task_arns = []
        for num in container_num_list:
            for retry in range(1, MAX_RUN_TASK_RETRIES + 1):
                try:
                    response = self.ecs.run_task(
                        count=num,
                        cluster=self.cluster_name,
                        launchType=launch_type,
                        networkConfiguration={
                            'awsvpcConfiguration': {
                                'subnets': subnets,
                                'securityGroups': security_groups,
                                'assignPublicIp': public_ip
                            }
                        },
                        taskDefinition=task_definition
                    )
                except ClientError as err:
                    # Handle AWS throttling with an exponential backoff
                    if retry == MAX_RUN_TASK_RETRIES:
                        raise Exception(
                            "MAX_RUN_TASK_RETRIES ({}) reached while spinning up containers due to AWS throttling.".format(MAX_RUN_TASK_RETRIES)
                        )
                    logger.info("Hit error: {}. Retrying".format(err))
                    countdown = 2 ** retry
                    logger.info("Sleeping for {} seconds".format(countdown))
                    time.sleep(countdown)
                else:
                    # No exception was raised, so the tasks were submitted successfully
                    break
            for task_response in response['tasks']:
                task_arns.append(task_response['taskArn'])

        # Wait for containers to finish spinning up
        not_running = task_arns[:]
        ip_addresses = []
        all_running = False
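        # Poll describe_tasks every 30 seconds; two polls per minute gives
        # CONTAINER_RUN_TIME_OUT_MINUTES worth of attempts before timing out.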
        for attempt in range(0, CONTAINER_RUN_TIME_OUT_MINUTES * 2):
            time.sleep(30)
            list_tasks_response = self.ecs.describe_tasks(cluster=self.cluster_name, tasks=not_running)['tasks']
            del not_running[:]
            for task_response in list_tasks_response:
                if task_response['lastStatus'] == 'RUNNING':
                    for container in task_response['containers']:
                        ip_addresses.append(container["networkInterfaces"][0]["privateIpv4Address"])
                else:
                    not_running.append(task_response['taskArn'])
            if not_running:
                logger.info("Still waiting on {} containers to spin up".format(len(not_running)))
            else:
                logger.info("Finished spinning up containers")
                all_running = True
                break
        if not all_running:
            raise Exception(
                "Timed out waiting to spin up all containers."
            )
        logger.info("Successfully booted up {} containers.".format(number_of_containers))

        # Generate .txt files containing IP addresses and task arns
        ip_list_string = " ".join(ip_addresses)
        logger.info("Container IP list: {}".format(ip_list_string))
        with open("pytest_container_ip_list.txt", "w") as ip_list_file:
            ip_list_file.write(ip_list_string)

        task_arn_list_string = " ".join(task_arns)
        logger.info("Container task arn list: {}".format(task_arn_list_string))
        with open("pytest_container_task_arns.txt", "w") as task_arn_file:
            task_arn_file.write(task_arn_list_string)

    def terminate_containers(self, task_arns, reason):
        """
        Terminates containers based on a list of task_arns.
        """
        for task_arn in task_arns:
            self.ecs.stop_task(
                cluster=self.cluster_name,
                task=task_arn,
                reason=reason
            )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="PytestContainerManager, manages ECS containers in an AWS cluster."
    )
    parser.add_argument('--region', '-g', default='us-east-1',
                        help="AWS region where ECS infrastructure lives. Defaults to us-east-1")
    parser.add_argument('--cluster', '-c', default="jenkins-worker-containers",
                        help="AWS Cluster name where the containers live. Defaults to "
                             "the testeng cluster: jenkins-worker-containers")
    parser.add_argument('--action', '-a', choices=['up', 'down'], default=None,
                        help="Action for PytestContainerManager to perform. "
                             "Either up for spinning up AWS ECS containers or down for stopping them")

    # Spinning up containers
    parser.add_argument('--num_containers', '-n', type=int, default=None,
                        help="Number of containers to spin up")
    parser.add_argument('--task_name', '-t', default=None,
                        help="Name of the task definition for spinning up workers")
    parser.add_argument('--subnets', '-s', nargs='+', default=None,
                        help="List of subnets for the containers to exist in")
    parser.add_argument('--security_groups', '-sg', nargs='+', default=None,
                        help="List of security groups to apply to the containers")
    parser.add_argument('--public_ip', choices=['ENABLED', 'DISABLED'],
                        default='DISABLED', help="Whether the containers should have a public IP")
    parser.add_argument('--launch_type', default='FARGATE', choices=['EC2', 'FARGATE'],
                        help="ECS launch type of container. Defaults to FARGATE")

    # Terminating containers
    parser.add_argument('--task_arns', '-arns', nargs='+', default=None,
                        help="Task arns to terminate")
    parser.add_argument('--reason', '-r', default="Finished executing tests",
                        help="Reason for terminating containers")
    args = parser.parse_args()
    containerManager = PytestContainerManager(args.region, args.cluster)

    if args.action == 'up':
        containerManager.spin_up_containers(
            args.num_containers,
            args.task_name,
            args.subnets,
            args.security_groups,
            args.public_ip,