Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/azure-cli/azure/cli/command_modules/storage/_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,9 @@ def load_arguments(self, _): # pylint: disable=too-many-locals, too-many-statem
c.argument('enable_blob_geo_priority_replication', arg_type=get_three_state_flag(),
options_list=['--enable-blob-geo-priority-replication', '--blob-geo-sla'],
help='Indicates whether Blob Geo Priority Replication is enabled for the storage account.')
c.argument('publish_ipv6_endpoint', arg_type=get_three_state_flag(),
arg_group='IPv6 Endpoint', is_preview=True,
help='A boolean flag which indicates whether IPv6 storage endpoints are to be published.')

with self.argument_context('storage account private-endpoint-connection',
resource_type=ResourceType.MGMT_STORAGE) as c:
Expand Down Expand Up @@ -547,6 +550,9 @@ def load_arguments(self, _): # pylint: disable=too-many-locals, too-many-statem
c.argument('enable_blob_geo_priority_replication', arg_type=get_three_state_flag(),
options_list=['--enable-blob-geo-priority-replication', '--blob-geo-sla'],
help='Indicates whether Blob Geo Priority Replication is enabled for the storage account.')
c.argument('publish_ipv6_endpoint', arg_type=get_three_state_flag(),
arg_group='IPv6 Endpoint', is_preview=True,
help='A boolean flag which indicates whether IPv6 storage endpoints are to be published.')

for scope in ['storage account create', 'storage account update']:
with self.argument_context(scope, arg_group='Customer managed key',
Expand Down Expand Up @@ -659,10 +665,12 @@ def load_arguments(self, _): # pylint: disable=too-many-locals, too-many-statem
c.argument('account_name', acct_name_type, id_part=None)

with self.argument_context('storage account network-rule', resource_type=ResourceType.MGMT_STORAGE) as c:
from ._validators import validate_ip_address
from ._validators import validate_ip_address, validate_ipv6_address
c.argument('account_name', acct_name_type, id_part=None)
c.argument('ip_address', nargs='*', help='IPv4 address or CIDR range. Can supply a list: --ip-address ip1 '
'[ip2]...', validator=validate_ip_address)
c.argument('ipv6_address', nargs='*', help='IPv6 address or CIDR range. Can supply a list: --ipv6-address ip1 '
'[ip2]...', validator=validate_ipv6_address, is_preview=True)
c.argument('subnet', help='Name or ID of subnet. If name is supplied, `--vnet-name` must be supplied.')
c.argument('vnet_name', help='Name of a virtual network.', validator=validate_subnet)
c.argument('action', action_type)
Expand Down
33 changes: 26 additions & 7 deletions src/azure-cli/azure/cli/command_modules/storage/_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2122,22 +2122,41 @@ def validate_fs_file_set_expiry(namespace):
pass


def validate_ip_address(namespace):
# if there are overlapping ip ranges, throw an exception
ip_address = namespace.ip_address

def _find_ip_address_overlap(ip_address, ipv6=False):
if not ip_address:
return

from azure.cli.core.azclierror import InvalidArgumentValueError
ip_address_networks = [ip_network(ip) for ip in ip_address]
if ipv6:
ipv4_address_in_ipv6 = [ip for ip in ip_address_networks if ip.version != 6]
if ipv4_address_in_ipv6:
raise InvalidArgumentValueError(f"ipv4 addresses {ipv4_address_in_ipv6} found in --ipv6-address")
else:
ipv6_address_in_ipv4 = [ip for ip in ip_address_networks if ip.version == 6]
if ipv6_address_in_ipv4:
raise InvalidArgumentValueError(f"ipv6 addresses {ipv6_address_in_ipv4} found in --ip-address")

error_str = "ipv6 addresses {} and {} provided are overlapping: --ipv6-address ip1 [ip2]..." if ipv6 else \
"ip addresses {} and {} provided are overlapping: --ip-address ip1 [ip2]..."
for idx, ip_address_network in enumerate(ip_address_networks):
for idx2, ip_address_network2 in enumerate(ip_address_networks):
if idx == idx2:
continue
if ip_address_network.overlaps(ip_address_network2):
from azure.cli.core.azclierror import InvalidArgumentValueError
raise InvalidArgumentValueError(f"ip addresses {ip_address_network} and {ip_address_network2} "
f"provided are overlapping: --ip_address ip1 [ip2]...")
raise InvalidArgumentValueError(error_str.format(ip_address_network, ip_address_network2))


def validate_ip_address(namespace):
# if there are overlapping ip ranges, throw an exception
ip_address = namespace.ip_address
_find_ip_address_overlap(ip_address=ip_address, ipv6=False)


def validate_ipv6_address(namespace):
# if there are overlapping ip ranges, throw an exception
ipv6_address = namespace.ipv6_address
_find_ip_address_overlap(ip_address=ipv6_address, ipv6=True)


# pylint: disable=too-few-public-methods
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def create_storage_account(cmd, resource_group_name, account_name, sku=None, loc
immutability_period_since_creation_in_days=None, immutability_policy_state=None,
allow_protected_append_writes=None, public_network_access=None, dns_endpoint_type=None,
enable_smb_oauth=None, zones=None, zone_placement_policy=None,
enable_blob_geo_priority_replication=None):
enable_blob_geo_priority_replication=None, publish_ipv6_endpoint=None):
StorageAccountCreateParameters, Kind, Sku, CustomDomain, AccessTier, Identity, Encryption, NetworkRuleSet = \
cmd.get_models('StorageAccountCreateParameters', 'Kind', 'Sku', 'CustomDomain', 'AccessTier', 'Identity',
'Encryption', 'NetworkRuleSet')
Expand Down Expand Up @@ -325,6 +325,12 @@ def create_storage_account(cmd, resource_group_name, account_name, sku=None, loc
GeoPriorityReplicationStatus = cmd.get_models('GeoPriorityReplicationStatus')
params.geo_priority_replication_status = GeoPriorityReplicationStatus(is_blob_enabled=enable_blob_geo_priority_replication)

if publish_ipv6_endpoint is not None:
DualStackEndpointPreference = cmd.get_models('DualStackEndpointPreference')
params.dual_stack_endpoint_preference = DualStackEndpointPreference(
publish_ipv6_endpoint=publish_ipv6_endpoint
)

return scf.storage_accounts.begin_create(resource_group_name, account_name, params)


Expand Down Expand Up @@ -420,7 +426,7 @@ def update_storage_account(cmd, instance, sku=None, tags=None, custom_domain=Non
immutability_period_since_creation_in_days=None, immutability_policy_state=None,
allow_protected_append_writes=None, public_network_access=None, upgrade_to_storagev2=None,
yes=None, enable_smb_oauth=None, zones=None, zone_placement_policy=None,
enable_blob_geo_priority_replication=None):
enable_blob_geo_priority_replication=None, publish_ipv6_endpoint=None):
StorageAccountUpdateParameters, Sku, CustomDomain, AccessTier, Identity, Encryption, NetworkRuleSet, Kind = \
cmd.get_models('StorageAccountUpdateParameters', 'Sku', 'CustomDomain', 'AccessTier', 'Identity', 'Encryption',
'NetworkRuleSet', 'Kind')
Expand Down Expand Up @@ -734,6 +740,12 @@ def update_storage_account(cmd, instance, sku=None, tags=None, custom_domain=Non
GeoPriorityReplicationStatus = cmd.get_models('GeoPriorityReplicationStatus')
params.geo_priority_replication_status = GeoPriorityReplicationStatus(is_blob_enabled=enable_blob_geo_priority_replication)

if publish_ipv6_endpoint is not None:
DualStackEndpointPreference = cmd.get_models('DualStackEndpointPreference')
params.dual_stack_endpoint_preference = DualStackEndpointPreference(
publish_ipv6_endpoint=publish_ipv6_endpoint
)

return params


Expand All @@ -746,11 +758,12 @@ def list_network_rules(client, resource_group_name, account_name):


def add_network_rule(cmd, client, resource_group_name, account_name, action='Allow', subnet=None,
vnet_name=None, ip_address=None, tenant_id=None, resource_id=None): # pylint: disable=unused-argument
vnet_name=None, ip_address=None, ipv6_address=None, tenant_id=None, resource_id=None): # pylint: disable=unused-argument
sa = client.get_properties(resource_group_name, account_name)
rules = sa.network_rule_set
if not subnet and not ip_address:
if not subnet and not ip_address and not ipv6_address:
logger.warning('No subnet or ip address supplied.')

if subnet:
from azure.mgmt.core.tools import is_valid_resource_id
if not is_valid_resource_id(subnet):
Expand All @@ -761,22 +774,13 @@ def add_network_rule(cmd, client, resource_group_name, account_name, action='All
rules.virtual_network_rules = [r for r in rules.virtual_network_rules
if r.virtual_network_resource_id.lower() != subnet.lower()]
rules.virtual_network_rules.append(VirtualNetworkRule(virtual_network_resource_id=subnet, action=action))

if ip_address:
IpRule = cmd.get_models('IPRule')
if not rules.ip_rules:
rules.ip_rules = []
for ip in ip_address:
to_modify = True
for x in rules.ip_rules:
existing_ip_network = ip_network(x.ip_address_or_range)
new_ip_network = ip_network(ip)
if new_ip_network.overlaps(existing_ip_network):
logger.warning("IP/CIDR %s overlaps with %s, which exists already. Not adding duplicates.",
ip, x.ip_address_or_range)
to_modify = False
break
if to_modify:
rules.ip_rules.append(IpRule(ip_address_or_range=ip, action=action))
rules.ip_rules = _process_add_ip(cmd, ip_address, rules.ip_rules, action=action, ipv6=False)

if ipv6_address:
rules.ipv6_rules = _process_add_ip(cmd, ipv6_address, rules.ipv6_rules, action=action, ipv6=True)

if resource_id:
ResourceAccessRule = cmd.get_models('ResourceAccessRule')
if not rules.resource_access_rules:
Expand All @@ -790,7 +794,26 @@ def add_network_rule(cmd, client, resource_group_name, account_name, action='All
return client.update(resource_group_name, account_name, params)


def remove_network_rule(cmd, client, resource_group_name, account_name, ip_address=None, subnet=None,
def _process_add_ip(cmd, ip_address, ip_rules, action, ipv6=False):
IpRule = cmd.get_models('IPRule')
if not ip_rules:
ip_rules = []
for ip in ip_address:
to_modify = True
for x in ip_rules:
existing_ip_network = ip_network(x.ip_address_or_range)
new_ip_network = ip_network(ip)
if new_ip_network.overlaps(existing_ip_network):
logger.warning("IP%s/CIDR %s overlaps with %s, which exists already. Not adding duplicates.",
"v6" if ipv6 else "v4", ip, x.ip_address_or_range)
to_modify = False
break
if to_modify:
ip_rules.append(IpRule(ip_address_or_range=ip, action=action))
return ip_rules


def remove_network_rule(cmd, client, resource_group_name, account_name, ip_address=None, ipv6_address=None, subnet=None,
vnet_name=None, tenant_id=None, resource_id=None): # pylint: disable=unused-argument
sa = client.get_properties(resource_group_name, account_name)
rules = sa.network_rule_set
Expand All @@ -802,6 +825,11 @@ def remove_network_rule(cmd, client, resource_group_name, account_name, ip_addre
rules.ip_rules = list(filter(lambda x: all(ip_network(x.ip_address_or_range) != i for i in to_remove),
rules.ip_rules))

if ipv6_address:
to_remove = [ip_network(x) for x in ipv6_address]
rules.ipv6_rules = list(filter(lambda x: all(ip_network(x.ip_address_or_range) != i for i in to_remove),
rules.ipv6_rules))

if resource_id:
rules.resource_access_rules = [x for x in rules.resource_access_rules if
not (x.tenant_id == tenant_id and x.resource_id == resource_id)]
Expand Down
Loading