mirror of https://github.com/llvm/torch-mlir
[Torch] Implements TorchToLinalg lowering of torch.ops.aten._weight_norm_interface (#3538)
Resolves https://github.com/nod-ai/SHARK-Turbine/issues/757. Adds TorchToLinalg lowering for `Aten_WeightNormInterfaceOp`. --------- Co-authored-by: Ubuntu <rbhowmik@RohanBhowmikVM.judsoscro3wupi0qm4bjlj5m3b.bx.internal.cloudapp.net>pull/3550/head
parent
714270a922
commit
0791a8860c
|
@ -9118,6 +9118,32 @@ def Torch_AtenDiagEmbedOp : Torch_Op<"aten.diag_embed", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_Aten_WeightNormInterfaceOp : Torch_Op<"aten._weight_norm_interface", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::_weight_norm_interface : (Tensor, Tensor, int) -> (Tensor, Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$v,
|
||||
AnyTorchTensorType:$g,
|
||||
Torch_IntType:$dim
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result0,
|
||||
AnyTorchOptionalTensorType:$result1
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult Aten_WeightNormInterfaceOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 3, 2);
|
||||
}
|
||||
void Aten_WeightNormInterfaceOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 3, 2);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -9739,6 +9739,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten._weight_norm_interface\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int) -> !torch.tuple<list<int>, list<int>> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" %1 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" %2 = torch.prim.TupleConstruct %0, %1 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
|
||||
" return %2 : !torch.tuple<list<int>, list<int>>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.slice.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.int) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
@ -10934,6 +10940,55 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten._weight_norm_interface\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int) -> !torch.tuple<int, int> {\n"
|
||||
" %int6 = torch.constant.int 6\n"
|
||||
" %int9 = torch.constant.int 9\n"
|
||||
" %int7 = torch.constant.int 7\n"
|
||||
" %int10 = torch.constant.int 10\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %false = torch.constant.bool false\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %0 = torch.prim.Uninitialized : !torch.tuple<int, int>\n"
|
||||
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %2:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %3 = torch.aten.eq.int %1#1, %2#1 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %3 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%2#1) : (!torch.int) -> !torch.bool\n"
|
||||
" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n"
|
||||
" torch.prim.If %5 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %6 = torch.aten.eq.int %2#1, %int10 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %7:2 = torch.prim.If %6 -> (!torch.bool, !torch.tuple<int, int>) {\n"
|
||||
" %9 = torch.prim.TupleConstruct %1#1, %int7 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
|
||||
" torch.prim.If.yield %true, %9 : !torch.bool, !torch.tuple<int, int>\n"
|
||||
" } else {\n"
|
||||
" %9 = torch.aten.eq.int %2#1, %int9 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %10:2 = torch.prim.If %9 -> (!torch.bool, !torch.tuple<int, int>) {\n"
|
||||
" %11 = torch.prim.TupleConstruct %1#1, %int6 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
|
||||
" torch.prim.If.yield %true, %11 : !torch.bool, !torch.tuple<int, int>\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false, %0 : !torch.bool, !torch.tuple<int, int>\n"
|
||||
" }\n"
|
||||
" torch.prim.If.yield %10#0, %10#1 : !torch.bool, !torch.tuple<int, int>\n"
|
||||
" }\n"
|
||||
" %8 = torch.prim.If %7#0 -> (!torch.tuple<int, int>) {\n"
|
||||
" torch.prim.If.yield %7#1 : !torch.tuple<int, int>\n"
|
||||
" } else {\n"
|
||||
" %9 = torch.prim.TupleConstruct %1#1, %2#1 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
|
||||
" torch.prim.If.yield %9 : !torch.tuple<int, int>\n"
|
||||
" }\n"
|
||||
" return %8 : !torch.tuple<int, int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.bernoulli_.float\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.any) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
|
|
|
@ -5553,6 +5553,63 @@ class DecomposeAtenInstanceNormOp
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAten_WeightNormInterfaceOp
|
||||
: public OpRewritePattern<Aten_WeightNormInterfaceOp> {
|
||||
using OpRewritePattern<Aten_WeightNormInterfaceOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(Aten_WeightNormInterfaceOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
|
||||
Location loc = op.getLoc();
|
||||
Value v = op.getV();
|
||||
Value g = op.getG();
|
||||
Value dim = op.getDim();
|
||||
|
||||
auto inputType = cast<BaseTensorType>(v.getType());
|
||||
if (!inputType.hasSizes())
|
||||
return rewriter.notifyMatchFailure(op, "expected input to have sizes");
|
||||
|
||||
if (!cast<ConstantIntOp>(dim.getDefiningOp()))
|
||||
return rewriter.notifyMatchFailure(op, "dim is not a ConstantIntOp");
|
||||
|
||||
auto sizes = inputType.getSizes();
|
||||
SmallVector<Value> keepDims;
|
||||
for (int64_t i = 0; i < static_cast<int64_t>(sizes.size()); ++i) {
|
||||
if (i !=
|
||||
static_cast<int64_t>(dim.getDefiningOp<ConstantIntOp>().getValue()))
|
||||
keepDims.push_back(
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i)));
|
||||
}
|
||||
|
||||
Value ord =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(2));
|
||||
Value keepdim =
|
||||
rewriter.create<ConstantBoolOp>(loc, rewriter.getBoolAttr(true));
|
||||
Value dtypeNone = rewriter.create<ConstantNoneOp>(loc);
|
||||
|
||||
Value dimList = rewriter.create<PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
|
||||
keepDims);
|
||||
|
||||
Value norm = rewriter.create<AtenLinalgVectorNormOp>(
|
||||
loc, v.getType(), v, ord, dimList, keepdim, dtypeNone);
|
||||
|
||||
auto vShape = rewriter.create<AtenSizeOp>(
|
||||
loc, Torch::ListType::get(rewriter.getI64Type()), v);
|
||||
|
||||
Value gDivNorm =
|
||||
rewriter.create<AtenDivTensorOp>(loc, g.getType(), g, norm);
|
||||
Value broadcastedGDivNorm =
|
||||
rewriter.create<AtenBroadcastToOp>(loc, v.getType(), gDivNorm, vShape);
|
||||
Value vMulBroadcastedGDivNorm = rewriter.create<AtenMulTensorOp>(
|
||||
loc, v.getType(), v, broadcastedGDivNorm);
|
||||
|
||||
rewriter.replaceOp(op, ArrayRef<Value>{vMulBroadcastedGDivNorm, norm});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenNativeLayerNormOp
|
||||
: public OpRewritePattern<AtenNativeLayerNormOp> {
|
||||
|
@ -7194,7 +7251,6 @@ public:
|
|||
rewriter.replaceOpWithNewOp<AtenEmbeddingBagPaddingIdxOp>(
|
||||
op, returnTypes, weight, indices, offsets, scaleGradByFreq, mode,
|
||||
sparse, perSampleWeights, includeLastOffset, paddingIdx);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -8704,6 +8760,8 @@ public:
|
|||
legalOpsSet.clear();
|
||||
legalOpsSet.insert(legalOps.begin(), legalOps.end());
|
||||
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAten_WeightNormInterfaceOp>(
|
||||
patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSoftmaxIntOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAten_SoftmaxOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxOp>(patterns);
|
||||
|
|
|
@ -418,6 +418,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
});
|
||||
target.addIllegalOp<AtenAddcmulOp>();
|
||||
target.addIllegalOp<AtenAddcdivOp>();
|
||||
target.addIllegalOp<Aten_WeightNormInterfaceOp>();
|
||||
target.addIllegalOp<AtenInstanceNormOp>();
|
||||
target.addIllegalOp<AtenLayerNormOp>();
|
||||
target.addIllegalOp<AtenNativeLayerNormOp>();
|
||||
|
|
|
@ -469,6 +469,7 @@ FX_IMPORTER_XFAIL_SET = {
|
|||
"UpSampleNearest2dDynamicFactor_basic",
|
||||
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||
"ViewSizeFromOtherTensor_basic",
|
||||
"WeightNormInterfaceModule_basic",
|
||||
}
|
||||
|
||||
FX_IMPORTER_CRASHING_SET = {
|
||||
|
@ -2629,6 +2630,7 @@ ONNX_XFAIL_SET = {
|
|||
"ViewNoChange1dModule_basic",
|
||||
"ViewNoChange2dModule_basic",
|
||||
"ViewNoChange3dModule_basic",
|
||||
"WeightNormInterfaceModule_basic",
|
||||
"_Convolution2DAllFalseModule_basic",
|
||||
"_Convolution2DBenchmarkModule_basic",
|
||||
"_Convolution2DCudnnModule_basic",
|
||||
|
|
|
@ -1771,6 +1771,9 @@ def aten〇native_group_norm〡shape(input: List[int], weight: Optional[List[int
|
|||
def aten〇instance_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], use_input_stats: bool, momentum: float, eps: float, cudnn_enabled: bool) -> List[int]:
|
||||
return upstream_shape_functions.unary(input)
|
||||
|
||||
def aten〇_weight_norm_interface〡shape(v: List[int], g: List[int], dim: int = 0) -> Tuple[List[int], List[int]]:
|
||||
return upstream_shape_functions.unary(v), upstream_shape_functions.unary(g)
|
||||
|
||||
def aten〇slice〇Tensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
|
||||
return upstream_shape_functions.slice(self, dim, start, end, step)
|
||||
|
||||
|
@ -2544,6 +2547,18 @@ def aten〇instance_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_
|
|||
input_rank, input_dtype = input_rank_dtype
|
||||
return input_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2, error_types={*all_integer_dtypes()}))
|
||||
def aten〇_weight_norm_interface〡dtype(v_rank_dtype: Tuple[int, int], g_rank_dtype: Tuple[int, int], dim: int = 0) -> Tuple[int, int]:
|
||||
v_rank, v_dtype = v_rank_dtype
|
||||
g_rank, g_dtype = g_rank_dtype
|
||||
assert v_dtype == g_dtype
|
||||
assert not is_integer_dtype(g_dtype)
|
||||
if g_dtype == torch.complex128:
|
||||
return v_dtype, torch.float64
|
||||
elif g_dtype == torch.complex64:
|
||||
return v_dtype, torch.float32
|
||||
return v_dtype, g_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
||||
def aten〇bernoulli_〇float〡dtype(self_rank_dtype: Tuple[int, int], p: float = 0.5, generator: Any = None) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
|
|
|
@ -732,6 +732,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
"aten::cosine_embedding_loss : (Tensor, Tensor, Tensor, float, int) -> (Tensor)"
|
||||
)
|
||||
emit("aten::diag_embed : (Tensor, int, int, int) -> (Tensor)")
|
||||
emit("aten::_weight_norm_interface : (Tensor, Tensor, int) -> (Tensor, Tensor)")
|
||||
|
||||
# Misc tensor ops.
|
||||
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
|
||||
|
|
|
@ -726,3 +726,28 @@ class RenormModuleFloat32DynamicDims(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: RenormModuleFloat32DynamicDims())
|
||||
def RenormModuleFloat32DynamicDims_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
class WeightNormInterfaceModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.dim = 2
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, v, g):
|
||||
return torch.ops.aten._weight_norm_interface(v, g, self.dim)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: WeightNormInterfaceModule())
|
||||
def WeightNormInterfaceModule_basic(module, tu: TestUtils):
|
||||
g = tu.rand(3, 10, 10)
|
||||
v = tu.rand(1, 1, 10)
|
||||
module.forward(g, v)
|
||||
|
|
Loading…
Reference in New Issue