[Torch] Implements TorchToLinalg lowering of torch.ops.aten._weight_norm_interface (#3538)

Resolves https://github.com/nod-ai/SHARK-Turbine/issues/757.

Adds TorchToLinalg lowering for `Aten_WeightNormInterfaceOp`.

---------

Co-authored-by: Ubuntu <rbhowmik@RohanBhowmikVM.judsoscro3wupi0qm4bjlj5m3b.bx.internal.cloudapp.net>
pull/3550/head
rohan-tan-bhowmik 2024-07-16 10:39:12 -07:00 committed by GitHub
parent 714270a922
commit 0791a8860c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 184 additions and 1 deletions

View File

@ -9118,6 +9118,32 @@ def Torch_AtenDiagEmbedOp : Torch_Op<"aten.diag_embed", [
}];
}
def Torch_Aten_WeightNormInterfaceOp : Torch_Op<"aten._weight_norm_interface", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::_weight_norm_interface : (Tensor, Tensor, int) -> (Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$v,
AnyTorchTensorType:$g,
Torch_IntType:$dim
);
let results = (outs
AnyTorchOptionalTensorType:$result0,
AnyTorchOptionalTensorType:$result1
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten_WeightNormInterfaceOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 2);
}
void Aten_WeightNormInterfaceOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 2);
}
}];
}
def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -9739,6 +9739,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten._weight_norm_interface\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int) -> !torch.tuple<list<int>, list<int>> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" %1 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list<int>) -> !torch.list<int>\n"
" %2 = torch.prim.TupleConstruct %0, %1 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
" return %2 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.slice.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@ -10934,6 +10940,55 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten._weight_norm_interface\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int) -> !torch.tuple<int, int> {\n"
" %int6 = torch.constant.int 6\n"
" %int9 = torch.constant.int 9\n"
" %int7 = torch.constant.int 7\n"
" %int10 = torch.constant.int 10\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %false = torch.constant.bool false\n"
" %true = torch.constant.bool true\n"
" %0 = torch.prim.Uninitialized : !torch.tuple<int, int>\n"
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %3 = torch.aten.eq.int %1#1, %2#1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%2#1) : (!torch.int) -> !torch.bool\n"
" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n"
" torch.prim.If %5 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %6 = torch.aten.eq.int %2#1, %int10 : !torch.int, !torch.int -> !torch.bool\n"
" %7:2 = torch.prim.If %6 -> (!torch.bool, !torch.tuple<int, int>) {\n"
" %9 = torch.prim.TupleConstruct %1#1, %int7 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" torch.prim.If.yield %true, %9 : !torch.bool, !torch.tuple<int, int>\n"
" } else {\n"
" %9 = torch.aten.eq.int %2#1, %int9 : !torch.int, !torch.int -> !torch.bool\n"
" %10:2 = torch.prim.If %9 -> (!torch.bool, !torch.tuple<int, int>) {\n"
" %11 = torch.prim.TupleConstruct %1#1, %int6 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" torch.prim.If.yield %true, %11 : !torch.bool, !torch.tuple<int, int>\n"
" } else {\n"
" torch.prim.If.yield %false, %0 : !torch.bool, !torch.tuple<int, int>\n"
" }\n"
" torch.prim.If.yield %10#0, %10#1 : !torch.bool, !torch.tuple<int, int>\n"
" }\n"
" %8 = torch.prim.If %7#0 -> (!torch.tuple<int, int>) {\n"
" torch.prim.If.yield %7#1 : !torch.tuple<int, int>\n"
" } else {\n"
" %9 = torch.prim.TupleConstruct %1#1, %2#1 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" torch.prim.If.yield %9 : !torch.tuple<int, int>\n"
" }\n"
" return %8 : !torch.tuple<int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.bernoulli_.float\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.any) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"

View File

@ -5553,6 +5553,63 @@ class DecomposeAtenInstanceNormOp
};
} // namespace
namespace {
class DecomposeAten_WeightNormInterfaceOp
: public OpRewritePattern<Aten_WeightNormInterfaceOp> {
using OpRewritePattern<Aten_WeightNormInterfaceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(Aten_WeightNormInterfaceOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value v = op.getV();
Value g = op.getG();
Value dim = op.getDim();
auto inputType = cast<BaseTensorType>(v.getType());
if (!inputType.hasSizes())
return rewriter.notifyMatchFailure(op, "expected input to have sizes");
if (!cast<ConstantIntOp>(dim.getDefiningOp()))
return rewriter.notifyMatchFailure(op, "dim is not a ConstantIntOp");
auto sizes = inputType.getSizes();
SmallVector<Value> keepDims;
for (int64_t i = 0; i < static_cast<int64_t>(sizes.size()); ++i) {
if (i !=
static_cast<int64_t>(dim.getDefiningOp<ConstantIntOp>().getValue()))
keepDims.push_back(
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i)));
}
Value ord =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(2));
Value keepdim =
rewriter.create<ConstantBoolOp>(loc, rewriter.getBoolAttr(true));
Value dtypeNone = rewriter.create<ConstantNoneOp>(loc);
Value dimList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
keepDims);
Value norm = rewriter.create<AtenLinalgVectorNormOp>(
loc, v.getType(), v, ord, dimList, keepdim, dtypeNone);
auto vShape = rewriter.create<AtenSizeOp>(
loc, Torch::ListType::get(rewriter.getI64Type()), v);
Value gDivNorm =
rewriter.create<AtenDivTensorOp>(loc, g.getType(), g, norm);
Value broadcastedGDivNorm =
rewriter.create<AtenBroadcastToOp>(loc, v.getType(), gDivNorm, vShape);
Value vMulBroadcastedGDivNorm = rewriter.create<AtenMulTensorOp>(
loc, v.getType(), v, broadcastedGDivNorm);
rewriter.replaceOp(op, ArrayRef<Value>{vMulBroadcastedGDivNorm, norm});
return success();
}
};
} // namespace
namespace {
class DecomposeAtenNativeLayerNormOp
: public OpRewritePattern<AtenNativeLayerNormOp> {
@ -7194,7 +7251,6 @@ public:
rewriter.replaceOpWithNewOp<AtenEmbeddingBagPaddingIdxOp>(
op, returnTypes, weight, indices, offsets, scaleGradByFreq, mode,
sparse, perSampleWeights, includeLastOffset, paddingIdx);
return success();
}
};
@ -8704,6 +8760,8 @@ public:
legalOpsSet.clear();
legalOpsSet.insert(legalOps.begin(), legalOps.end());
addPatternIfTargetOpIsIllegal<DecomposeAten_WeightNormInterfaceOp>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSoftmaxIntOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_SoftmaxOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxOp>(patterns);

View File

@ -418,6 +418,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
});
target.addIllegalOp<AtenAddcmulOp>();
target.addIllegalOp<AtenAddcdivOp>();
target.addIllegalOp<Aten_WeightNormInterfaceOp>();
target.addIllegalOp<AtenInstanceNormOp>();
target.addIllegalOp<AtenLayerNormOp>();
target.addIllegalOp<AtenNativeLayerNormOp>();

View File

@ -469,6 +469,7 @@ FX_IMPORTER_XFAIL_SET = {
"UpSampleNearest2dDynamicFactor_basic",
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
"ViewSizeFromOtherTensor_basic",
"WeightNormInterfaceModule_basic",
}
FX_IMPORTER_CRASHING_SET = {
@ -2629,6 +2630,7 @@ ONNX_XFAIL_SET = {
"ViewNoChange1dModule_basic",
"ViewNoChange2dModule_basic",
"ViewNoChange3dModule_basic",
"WeightNormInterfaceModule_basic",
"_Convolution2DAllFalseModule_basic",
"_Convolution2DBenchmarkModule_basic",
"_Convolution2DCudnnModule_basic",

View File

@ -1771,6 +1771,9 @@ def atennative_group_norm〡shape(input: List[int], weight: Optional[List[int
def ateninstance_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], use_input_stats: bool, momentum: float, eps: float, cudnn_enabled: bool) -> List[int]:
return upstream_shape_functions.unary(input)
def aten_weight_norm_interface〡shape(v: List[int], g: List[int], dim: int = 0) -> Tuple[List[int], List[int]]:
return upstream_shape_functions.unary(v), upstream_shape_functions.unary(g)
def atensliceTensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
return upstream_shape_functions.slice(self, dim, start, end, step)
@ -2544,6 +2547,18 @@ def ateninstance_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_
input_rank, input_dtype = input_rank_dtype
return input_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2, error_types={*all_integer_dtypes()}))
def aten_weight_norm_interface〡dtype(v_rank_dtype: Tuple[int, int], g_rank_dtype: Tuple[int, int], dim: int = 0) -> Tuple[int, int]:
v_rank, v_dtype = v_rank_dtype
g_rank, g_dtype = g_rank_dtype
assert v_dtype == g_dtype
assert not is_integer_dtype(g_dtype)
if g_dtype == torch.complex128:
return v_dtype, torch.float64
elif g_dtype == torch.complex64:
return v_dtype, torch.float32
return v_dtype, g_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def atenbernoulli_float〡dtype(self_rank_dtype: Tuple[int, int], p: float = 0.5, generator: Any = None) -> int:
self_rank, self_dtype = self_rank_dtype

View File

@ -732,6 +732,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
"aten::cosine_embedding_loss : (Tensor, Tensor, Tensor, float, int) -> (Tensor)"
)
emit("aten::diag_embed : (Tensor, int, int, int) -> (Tensor)")
emit("aten::_weight_norm_interface : (Tensor, Tensor, int) -> (Tensor, Tensor)")
# Misc tensor ops.
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")

View File

@ -726,3 +726,28 @@ class RenormModuleFloat32DynamicDims(torch.nn.Module):
@register_test_case(module_factory=lambda: RenormModuleFloat32DynamicDims())
def RenormModuleFloat32DynamicDims_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 3))
# ==============================================================================
class WeightNormInterfaceModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.dim = 2
@export
@annotate_args(
[
None,
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
]
)
def forward(self, v, g):
return torch.ops.aten._weight_norm_interface(v, g, self.dim)
@register_test_case(module_factory=lambda: WeightNormInterfaceModule())
def WeightNormInterfaceModule_basic(module, tu: TestUtils):
g = tu.rand(3, 10, 10)
v = tu.rand(1, 1, 10)
module.forward(g, v)