
# Autogenerated by mlir-tblgen; don't manually edit.

from enum import IntEnum, auto, IntFlag
from jaxlib.mlir.dialects._ods_common import _cext as _ods_cext
from jaxlib.mlir.ir import register_attribute_builder
_ods_ir = _ods_cext.ir

class Dimension(IntEnum):
    """a dimension, either 'x', 'y', or 'z'"""

    # NOTE(review): only str() is customised below; format()/f-strings use
    # IntEnum's own formatting, which has varied across Python versions —
    # confirm before relying on f"{member}" producing the letter.
    x = 0
    y = 1
    z = 2

    def __str__(self):
        """Return the spelling of this member ("x", "y" or "z").

        Raises:
            ValueError: if ``self`` matches none of the known members.
        """
        spellings = (
            (Dimension.x, "x"),
            (Dimension.y, "y"),
            (Dimension.z, "z"),
        )
        for member, spelling in spellings:
            if self is member:
                return spelling
        raise ValueError("Unknown Dimension enum entry.")



@register_attribute_builder("MosaicGPU_Dimension")
def _mosaicgpu_dimension(x, context):
    """Build the attribute registered as "MosaicGPU_Dimension".

    Args:
        x: a value convertible with ``int()`` (e.g. a ``Dimension`` member).
        context: the MLIR context passed through to the type constructor.

    Returns:
        An ``IntegerAttr`` of 32-bit signless integer type holding ``int(x)``.
    """
    i32 = _ods_ir.IntegerType.get_signless(32, context=context)
    return _ods_ir.IntegerAttr.get(i32, int(x))

class SwizzlingMode(IntEnum):
    """What swizzling to use for a memory access."""

    # NOTE(review): the "no swizzle" member carries the value 16, not 0;
    # keep the numeric values as-is since they are what int(member) yields.
    kNoSwizzle = 16
    k32ByteSwizzle = 32
    k64ByteSwizzle = 64
    k128ByteSwizzle = 128

    def __str__(self):
        """Return the spelling of this member (its identifier, e.g. "kNoSwizzle").

        Raises:
            ValueError: if ``self`` matches none of the known members.
        """
        spellings = (
            (SwizzlingMode.kNoSwizzle, "kNoSwizzle"),
            (SwizzlingMode.k32ByteSwizzle, "k32ByteSwizzle"),
            (SwizzlingMode.k64ByteSwizzle, "k64ByteSwizzle"),
            (SwizzlingMode.k128ByteSwizzle, "k128ByteSwizzle"),
        )
        for member, spelling in spellings:
            if self is member:
                return spelling
        raise ValueError("Unknown SwizzlingMode enum entry.")



@register_attribute_builder("MosaicGPU_SwizzlingMode")
def _mosaicgpu_swizzlingmode(x, context):
    """Build the attribute registered as "MosaicGPU_SwizzlingMode".

    Args:
        x: a value convertible with ``int()`` (e.g. a ``SwizzlingMode`` member).
        context: the MLIR context passed through to the type constructor.

    Returns:
        An ``IntegerAttr`` of 32-bit signless integer type holding ``int(x)``.
    """
    i32 = _ods_ir.IntegerType.get_signless(32, context=context)
    return _ods_ir.IntegerAttr.get(i32, int(x))

class WGMMALayout(IntEnum):
    """The layout of the tiles of a WGMMA operation"""

    RowMajor = 0
    ColumnMajor = 1

    def __str__(self):
        """Return the spelling of this member ("RowMajor" or "ColumnMajor").

        Raises:
            ValueError: if ``self`` matches none of the known members.
        """
        spellings = (
            (WGMMALayout.RowMajor, "RowMajor"),
            (WGMMALayout.ColumnMajor, "ColumnMajor"),
        )
        for member, spelling in spellings:
            if self is member:
                return spelling
        raise ValueError("Unknown WGMMALayout enum entry.")



@register_attribute_builder("MosaicGPU_WGMMALayout")
def _mosaicgpu_wgmmalayout(x, context):
    """Build the attribute registered as "MosaicGPU_WGMMALayout".

    Args:
        x: a value convertible with ``int()`` (e.g. a ``WGMMALayout`` member).
        context: the MLIR context passed through to the type constructor.

    Returns:
        An ``IntegerAttr`` of 32-bit signless integer type holding ``int(x)``.
    """
    i32 = _ods_ir.IntegerType.get_signless(32, context=context)
    return _ods_ir.IntegerAttr.get(i32, int(x))

