mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add support for aten._unsafe_index_put.hacked_twin op
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/2392/head snapshot-20230811.927
parent
f0a8f273f7
commit
e61ef1ee54
|
@ -273,6 +273,9 @@ TORCHDYNAMO_XFAIL_SET = {
|
|||
# ERROR: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {}
|
||||
"ScatterValueIntModule_basic",
|
||||
|
||||
# AssertionError: Unregistered operation: torch.aten._unsafe_index_put
|
||||
"UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
|
||||
|
||||
}
|
||||
|
||||
TORCHDYNAMO_CRASHING_SET = {
|
||||
|
@ -1431,4 +1434,5 @@ LTC_XFAIL_SET = {
|
|||
"ScatterValueIntModule_basic",
|
||||
"IndexTensorNegativeIndexModule_basic",
|
||||
"UniformStaticShapeModule_basic",
|
||||
"UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
|
||||
}
|
||||
|
|
|
@ -4324,6 +4324,32 @@ def Torch_AtenIndexPut_HackedTwinOp : Torch_Op<"aten.index_put_.hacked_twin", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_Aten_UnsafeIndexPutHackedTwinOp : Torch_Op<"aten._unsafe_index_put.hacked_twin", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::_unsafe_index_put.hacked_twin : (Tensor, Tensor[], Tensor, bool) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchListOfTensorType:$indices,
|
||||
AnyTorchTensorType:$values,
|
||||
Torch_BoolType:$accumulate
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult Aten_UnsafeIndexPutHackedTwinOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||
}
|
||||
void Aten_UnsafeIndexPutHackedTwinOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 4, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenLinearOp : Torch_Op<"aten.linear", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -8552,6 +8552,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten._unsafe_index_put.hacked_twin\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<tuple<int, int>>, %arg2: !torch.tuple<int, int>, %arg3: !torch.bool) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten._index_put_impl\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<optional<tuple<int, int>>>, %arg2: !torch.tuple<int, int>, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
|
|
|
@ -3254,6 +3254,25 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose `aten._unsafe_indexPut.hackedTwin` op into `aten._index_put_impl`
|
||||
// op.
|
||||
class DecomposeAten_UnsafeIndexPutHackedTwinOp
|
||||
: public OpRewritePattern<Aten_UnsafeIndexPutHackedTwinOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(Aten_UnsafeIndexPutHackedTwinOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
|
||||
rewriter.replaceOpWithNewOp<Aten_IndexPutImplOp>(
|
||||
op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(),
|
||||
op.getAccumulate(),
|
||||
/*unsafe=*/cstFalse);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose `aten.pad` op into `aten.constantPadNd` op.
|
||||
class DecomposeAtenPadOp : public OpRewritePattern<AtenPadOp> {
|
||||
|
@ -4889,6 +4908,7 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DeomposeAtenNativeDropoutOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexPutHackedTwinOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAten_UnsafeIndexPutHackedTwinOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenPadOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenToDtypeLayoutOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenToDeviceOp>(patterns);
|
||||
|
|
|
@ -443,6 +443,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenNativeDropoutOp>();
|
||||
target.addIllegalOp<AtenNewEmptyOp>();
|
||||
target.addIllegalOp<AtenIndexPutHackedTwinOp>();
|
||||
target.addIllegalOp<Aten_UnsafeIndexPutHackedTwinOp>();
|
||||
target.addIllegalOp<AtenPadOp>();
|
||||
target.addIllegalOp<AtenToDtypeLayoutOp>();
|
||||
target.addIllegalOp<AtenToDeviceOp>();
|
||||
|
|
|
@ -274,6 +274,7 @@ def _lower_mlir_module(verbose, output_type, module):
|
|||
print("Torch Backend IR")
|
||||
print(module)
|
||||
|
||||
# module.dump()
|
||||
if output_type == OutputType.TORCH:
|
||||
return module
|
||||
|
||||
|
@ -292,6 +293,7 @@ def _lower_mlir_module(verbose, output_type, module):
|
|||
module,
|
||||
"builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
"Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR")
|
||||
# module.dump()
|
||||
if verbose:
|
||||
print("\n====================")
|
||||
print("LINALG Backend IR")
|
||||
|
@ -446,6 +448,7 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
|
|||
if output_type == OutputType.RAW:
|
||||
return mb.module
|
||||
|
||||
# mb.module.dump()
|
||||
option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + \
|
||||
" extra-library=" + extra_library_file_name + "}"
|
||||
run_pipeline_with_repro_report(
|
||||
|
|
|
@ -1652,6 +1652,11 @@ def aten〇index_put〇hacked_twin〡dtype(self_rank_dtype: Tuple[int, int], ind
|
|||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(_index_put_invocations)
|
||||
def aten〇_unsafe_index_put〇hacked_twin〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: List[Tuple[int, int]], values_rank_dtype: Tuple[int, int], accumulate: bool = False) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(_index_put_invocations)
|
||||
def aten〇_index_put_impl〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: List[Optional[Tuple[int, int]]], values_rank_dtype: Tuple[int, int], accumulate: bool = False, unsafe: bool = False) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
|
|
|
@ -353,6 +353,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
"aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)")
|
||||
emit_with_mutating_variants(
|
||||
"aten::index_put.hacked_twin : (Tensor, Tensor[], Tensor, bool) -> (Tensor)")
|
||||
emit("aten::_unsafe_index_put.hacked_twin : (Tensor, Tensor[], Tensor, bool) -> (Tensor)")
|
||||
|
||||
# Non-elementwise tensor compute ops
|
||||
emit("aten::linear : (Tensor, Tensor, Tensor?) -> (Tensor)")
|
||||
|
|
|
@ -843,6 +843,35 @@ def IndexPutHackedTwin3DIntAccumulateModule_basic(module, tu: TestUtils):
|
|||
module.forward(tu.randint(10, 8, 6, high=1000), tu.randint(5, high=4),
|
||||
tu.randint(5, 8, 6, high=1000))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# UnsafeIndexPutHackedTwin tests are using the aten._unsafe_index_put.hacked_twin operator.
|
||||
|
||||
|
||||
class UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten._unsafe_index_put(input, [index],
|
||||
value,
|
||||
accumulate=False)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule())
|
||||
def UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(100), tu.randint(250, high=100), tu.rand(250))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ScatterSrcStaticModule(torch.nn.Module):
|
||||
|
|
Loading…
Reference in New Issue