Add aten.isclose support and its torch-to-tosa lowering (#2512)

Add aten.isclose op
Add its torch-to-tosa lowering
Update the TorchToTosa/basic.mlir tests


To test e2e tosa lowering:
`python -m e2e_testing.main -v -c=tosa`

---------

Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>
pull/2511/head
Ze Zhang 2023-10-16 09:44:53 -07:00 committed by GitHub
parent e649e06b7b
commit f2c53b8ca5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 175 additions and 0 deletions

View File

@ -18,6 +18,8 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
# 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8 # 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
"UnflattenStaticModule_basic", "UnflattenStaticModule_basic",
"IscloseStaticModule_basic",
"IscloseStaticModuleTrue_basic",
} }
TORCHDYNAMO_XFAIL_SET = { TORCHDYNAMO_XFAIL_SET = {
@ -928,6 +930,8 @@ STABLEHLO_CRASHING_SET = {
# Write the TOSA set as a "passing" set as it is very early in development # Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet. # and very few tests work yet.
TOSA_PASS_SET = { TOSA_PASS_SET = {
"IscloseStaticModule_basic",
"IscloseStaticModuleTrue_basic",
"TileBigDimsSizeModule_basic", "TileBigDimsSizeModule_basic",
"TileSmallDimsSizeModule_basic", "TileSmallDimsSizeModule_basic",
"IndexPutImpl2DNoneIndexStaticModule_basic", "IndexPutImpl2DNoneIndexStaticModule_basic",

View File

@ -4162,6 +4162,33 @@ def Torch_AtenViewAsRealOp : Torch_Op<"aten.view_as_real", [
}]; }];
} }
def Torch_AtenIscloseOp : Torch_Op<"aten.isclose", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::isclose : (Tensor, Tensor, float, float, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$other,
Torch_FloatType:$rtol,
Torch_FloatType:$atol,
Torch_BoolType:$equal_nan
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenIscloseOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 1);
}
void AtenIscloseOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
}
def Torch_AtenUnbindCopyIntOp : Torch_Op<"aten.unbind_copy.int", [ def Torch_AtenUnbindCopyIntOp : Torch_Op<"aten.unbind_copy.int", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,

View File

@ -3920,6 +3920,59 @@ LogicalResult ConvertAtenOp<AtenLeTensorOp>::matchAndRewrite(
return success(); return success();
} }
template <>
LogicalResult ConvertAtenOp<AtenIscloseOp>::matchAndRewrite(
AtenIscloseOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// check args
double rtol, atol;
bool equalNan;
if (!matchPattern(op.getRtol(), m_TorchConstantFloat(&rtol)))
return rewriter.notifyMatchFailure(op, "rtol must be a scalar constant");
if (!matchPattern(op.getAtol(), m_TorchConstantFloat(&atol)))
return rewriter.notifyMatchFailure(op, "atol must be a scalar constant");
if (!matchPattern(op.getEqualNan(), m_TorchConstantBool(&equalNan)))
return rewriter.notifyMatchFailure(
op, "unimplemented: equal_nan is expected to be false");
// check tensor type.
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
auto otherType = adaptor.getOther().getType().dyn_cast<TensorType>();
if (!selfType || !otherType)
return rewriter.notifyMatchFailure(
op, "Only tensor types input are currently supported");
if (!selfType.hasStaticShape() || !otherType.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "Only tensor types with static shape are supported");
if (!selfType.getElementType().isa<mlir::FloatType>() ||
!otherType.getElementType().isa<mlir::FloatType>()) {
return rewriter.notifyMatchFailure(
op, "unimplemented: only FP element type is supported");
}
auto rhsSubOp = rewriter.create<tosa::SubOp>(
op->getLoc(), selfType, adaptor.getSelf(), adaptor.getOther());
auto rhsAbsOp =
rewriter.create<tosa::AbsOp>(op->getLoc(), selfType, rhsSubOp);
auto lhsAbsOp =
rewriter.create<tosa::AbsOp>(op->getLoc(), otherType, adaptor.getOther());
auto rtolConstOp =
tosa::getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(rtol));
auto mulOp = rewriter.create<tosa::MulOp>(op->getLoc(), otherType,
rtolConstOp, lhsAbsOp, /*shift=*/0);
auto atolConstOp =
tosa::getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(atol));
auto addOp =
rewriter.create<tosa::AddOp>(op->getLoc(), otherType, atolConstOp, mulOp);
auto outType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tosa::GreaterEqualOp>(op, outType, addOp,
rhsAbsOp);
return success();
}
template <> template <>
LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite( LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
AtenClampOp op, OpAdaptor adaptor, AtenClampOp op, OpAdaptor adaptor,
@ -5134,6 +5187,7 @@ public:
INSERT_ATENOP_PATTERN(AtenRemainderScalarOp); INSERT_ATENOP_PATTERN(AtenRemainderScalarOp);
INSERT_ATENOP_PATTERN(AtenCatOp); INSERT_ATENOP_PATTERN(AtenCatOp);
INSERT_ATENOP_PATTERN(AtenSqrtOp); INSERT_ATENOP_PATTERN(AtenSqrtOp);
INSERT_ATENOP_PATTERN(AtenIscloseOp);
#undef INSERT_ATENOP_PATTERN #undef INSERT_ATENOP_PATTERN
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \ #define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \

View File

@ -7480,6 +7480,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.isclose\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.unsqueeze\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.unsqueeze\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unsqueeze(%arg0, %arg1) : (!torch.list<int>, !torch.int) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.unsqueeze(%arg0, %arg1) : (!torch.list<int>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
@ -9093,6 +9097,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %int11 = torch.constant.int 11\n" " %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n" " return %int11 : !torch.int\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.isclose\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.scaled_dot_product_attention\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.optional<float>) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.scaled_dot_product_attention\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.optional<float>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n" " return %0#1 : !torch.int\n"

View File

@ -844,6 +844,9 @@ def atenltTensor〡shape(self: List[int], other: List[int]) -> List[int]:
def atenleTensor〡shape(self: List[int], other: List[int]) -> List[int]: def atenleTensor〡shape(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.broadcast(self, other) return upstream_shape_functions.broadcast(self, other)
def atenisclose〡shape(self: List[int], other: List[int], rtol: float = 1.0000000000000001e-05, atol: float = 1e-08, equal_nan: bool = False) -> List[int]:
return upstream_shape_functions.broadcast(self, other)
def atenunsqueeze〡shape(self: List[int], dim: int) -> List[int]: def atenunsqueeze〡shape(self: List[int], dim: int) -> List[int]:
return upstream_shape_functions.unsqueeze(self, dim) return upstream_shape_functions.unsqueeze(self, dim)
@ -2171,6 +2174,10 @@ def atenlogical_and〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtyp
def atenlogical_not〡dtype(self_rank_dtype: Tuple[int, int]) -> int: def atenlogical_not〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
return torch.bool return torch.bool
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2))
def atenisclose〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], rtol: float = 1.0000000000000001e-05, atol: float = 1e-08, equal_nan: bool = False) -> int:
return torch.bool
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(3, 4, 32, 16), (3, 4, 32, 16), (3, 4, 32, 16)])) @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(3, 4, 32, 16), (3, 4, 32, 16), (3, 4, 32, 16)]))
def atenscaled_dot_product_attention〡dtype(query_rank_dtype: Tuple[int, int], key_rank_dtype: Tuple[int, int], value_rank_dtype: Tuple[int, int], attn_mask_rank_dtype: Optional[Tuple[int, int]] = None, dropout_p: float = 0., is_causal: bool = False, scale: Optional[float] = None) -> int: def atenscaled_dot_product_attention〡dtype(query_rank_dtype: Tuple[int, int], key_rank_dtype: Tuple[int, int], value_rank_dtype: Tuple[int, int], attn_mask_rank_dtype: Optional[Tuple[int, int]] = None, dropout_p: float = 0., is_causal: bool = False, scale: Optional[float] = None) -> int:
_, query_dtype = query_rank_dtype _, query_dtype = query_rank_dtype

View File

@ -342,6 +342,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::imag : (Tensor) -> (Tensor)") emit("aten::imag : (Tensor) -> (Tensor)")
emit("aten::view_as_complex : (Tensor) -> (Tensor)") emit("aten::view_as_complex : (Tensor) -> (Tensor)")
emit("aten::view_as_real : (Tensor) -> (Tensor)") emit("aten::view_as_real : (Tensor) -> (Tensor)")
emit("aten::isclose : (Tensor, Tensor, float, float, bool) -> (Tensor)")
# Ops with dynamic number of outputs # Ops with dynamic number of outputs
emit("aten::unbind_copy.int : (Tensor, int) -> (Tensor[])") emit("aten::unbind_copy.int : (Tensor, int) -> (Tensor[])")

View File

@ -4580,3 +4580,48 @@ class Add_Module(torch.nn.Module):
@register_test_case(module_factory=lambda: Add_Module()) @register_test_case(module_factory=lambda: Add_Module())
def Add_Module_basic(module, tu: TestUtils): def Add_Module_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3)) module.forward(tu.rand(2, 3))
# ==============================================================================
class IscloseStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([5, 5], torch.float32, True),
([5, 5], torch.float32, True),
])
def forward(self, x, y):
return torch.isclose(x, y)
@register_test_case(module_factory=lambda: IscloseStaticModule())
def IscloseStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 5), tu.rand(5, 5))
# ==============================================================================
class IscloseStaticModuleTrue(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer('tensor', torch.ones(1))
@export
@annotate_args([
None,
([5, 5], torch.float32, True),
])
def forward(self, x):
return torch.isclose(x, self.tensor)
@register_test_case(module_factory=lambda: IscloseStaticModuleTrue())
def IscloseStaticModuleTrue_basic(module, tu: TestUtils):
module.forward(torch.ones(5, 5))

View File

@ -1155,3 +1155,32 @@ func.func @torch.aten.remainder.Scalar(%arg0: !torch.vtensor<[2, 4],f32>) -> !to
%0 = torch.aten.remainder.Scalar %arg0, %int2 : !torch.vtensor<[2, 4],f32>, !torch.int -> !torch.vtensor<[2, 4],f32> %0 = torch.aten.remainder.Scalar %arg0, %int2 : !torch.vtensor<[2, 4],f32>, !torch.int -> !torch.vtensor<[2, 4],f32>
return %0 : !torch.vtensor<[2, 4],f32> return %0 : !torch.vtensor<[2, 4],f32>
} }
// -----
// CHECK-LABEL: func.func @forward(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[5,5],f32>,
// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[5,5],i1> {
// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32>
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32>
// CHECK: %[[ATOL:.*]] = torch.constant.float 1.000000e-08
// CHECK: %[[RTOL:.*]] = torch.constant.float 1.000000e-05
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[VAL_2:.*]] = tosa.sub %[[VAL_0]], %[[VAL_1]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xf32>
// CHECK: %[[VAL_3:.*]] = tosa.abs %[[VAL_2]] : (tensor<5x5xf32>) -> tensor<5x5xf32>
// CHECK: %[[VAL_4:.*]] = tosa.abs %[[VAL_1]] : (tensor<5x5xf32>) -> tensor<5x5xf32>
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_5]], %[[VAL_4]] {shift = 0 : i32} : (tensor<f32>, tensor<5x5xf32>) -> tensor<5x5xf32>
// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<9.99999993E-9> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_8:.*]] = tosa.add %[[VAL_7]], %[[VAL_6]] : (tensor<f32>, tensor<5x5xf32>) -> tensor<5x5xf32>
// CHECK: %[[VAL_9:.*]] = tosa.greater_equal %[[VAL_8]], %[[VAL_3]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xi1>
// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<5x5xi1> -> !torch.vtensor<[5,5],i1>
// CHECK: return %[[VAL_10]] : !torch.vtensor<[5,5],i1>
// CHECK: }
func.func @forward(%arg0: !torch.vtensor<[5,5],f32>, %arg1: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[5,5],i1> {
%float1.000000e-08 = torch.constant.float 1.000000e-08
%float1.000000e-05 = torch.constant.float 1.000000e-05
%false = torch.constant.bool false
%0 = torch.aten.isclose %arg0, %arg1, %float1.000000e-05, %float1.000000e-08, %false : !torch.vtensor<[5,5],f32>, !torch.vtensor<[5,5],f32>, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[5,5],i1>
return %0 : !torch.vtensor<[5,5],i1>
}