mirror of https://github.com/llvm/torch-mlir
* [tosa] Support for AtenNe[Tensor|Scalar]Op, AtenLog2Op, AtenBitwiseAndTensorOp, AtenSquareOp and AtenThresholdOp
* Fix for Issue #532 - Mixed input types for few ops and updated few tests to use i32 instead of i64

Signed-off-by: Anup Gangwar <anup.gangwar@arm.com>
Co-authored-by: Anup Gangwar <anup.gangwar@arm.com>

pull/587/head
parent c1167853db
commit 756b75fb2d
@@ -739,14 +739,14 @@ class ElementwiseMulScalarModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-1, -1], torch.int64, True),
+        ([-1, -1], torch.int32, True),
     ])
     def forward(self, x):
         return torch.mul(x, 8.0)
 
 @register_test_case(module_factory=lambda: ElementwiseMulScalarModule())
 def ElementwiseMulScalarModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(10, (3, 4)))
+    module.forward(torch.randint(10, (3, 4), dtype=torch.int32))
 
 
 class ElementwiseMulTensorFloatModule(torch.nn.Module):
@@ -1045,7 +1045,7 @@ class ElementwiseSubScalarIntModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-1, -1], torch.int64, True),
+        ([-1, -1], torch.int32, True),
     ])
     def forward(self, x):
         return torch.sub(x, 2.1, alpha=2)
@@ -1053,7 +1053,7 @@ class ElementwiseSubScalarIntModule(torch.nn.Module):
 
 @register_test_case(module_factory=lambda: ElementwiseSubScalarIntModule())
 def ElementwiseSubScalarIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(10, (3, 4)))
+    module.forward(torch.randint(10, (3, 4), dtype=torch.int32))
 
 
 class ElementwiseSubScalarFloatModule(torch.nn.Module):
@@ -1072,8 +1072,7 @@ class ElementwiseSubScalarFloatModule(torch.nn.Module):
 def ElementwiseSubScalarFloatModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
 
-
-class ElementwiseAddScalarIntModule(torch.nn.Module):
+class ElementwiseAddScalarInt64Module(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -1085,9 +1084,26 @@ class ElementwiseAddScalarIntModule(torch.nn.Module):
     def forward(self, x):
         return torch.add(x, 3.0)
 
+@register_test_case(module_factory=lambda: ElementwiseAddScalarInt64Module())
+def ElementwiseAddScalarInt64Module_basic(module, tu: TestUtils):
+    module.forward(torch.randint(10, (3, 4)))
+
+
+class ElementwiseAddScalarIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int32, True),
+    ])
+    def forward(self, x):
+        return torch.add(x, 3.0)
+
 @register_test_case(module_factory=lambda: ElementwiseAddScalarIntModule())
 def ElementwiseAddScalarIntModule_basic(module, tu: TestUtils):
-    module.forward(torch.randint(10, (3, 4)))
+    module.forward(torch.randint(10, (2, 3), dtype=torch.int32))
 
 
 class ElementwiseAddScalarFloatModule(torch.nn.Module):
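Note: the dtype changes above (torch.int64 to torch.int32, plus explicit dtype=torch.int32 in torch.randint) line up with the TorchToTosa restriction further down that rejects integer element types wider than 32 bits. A minimal Python sketch of the updated test inputs and ops, assuming only a local PyTorch install (the e2e harness is not needed to check the op semantics):

    import torch

    # i32 inputs mirror the updated annotate_args signatures above.
    x = torch.randint(10, (3, 4), dtype=torch.int32)

    # Same elementwise ops the modules exercise; results promote to float.
    y_mul = torch.mul(x, 8.0)
    y_sub = torch.sub(x, 2.1, alpha=2)
    y_add = torch.add(x, 3.0)
    assert y_mul.dtype == y_sub.dtype == y_add.dtype == torch.float32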
@@ -12,6 +12,24 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
 # ==============================================================================
 
 
+class Threshold1dIntI32Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.int32, True),
+    ])
+
+    def forward(self, input):
+        return torch.ops.aten.threshold(input, 1, 2)
+
+@register_test_case(module_factory=lambda: Threshold1dIntI32Module())
+def Threshold1dIntI32Module_basic(module, tu: TestUtils):
+    module.forward(torch.randint(10, (4,), dtype=torch.int32))
+
+
 class Threshold1dIntModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
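Note: the new Threshold1dIntI32Module exercises aten.threshold on an i32 input. threshold(x, t, v) keeps elements where x > t and substitutes v elsewhere, which is exactly the greater/select structure the TOSA lowering below emits. A sketch of that equivalence in plain PyTorch (illustration only):

    import torch

    x = torch.randint(10, (4,), dtype=torch.int32)
    out = torch.ops.aten.threshold(x, 1, 2)

    # Equivalent select: keep x where x > threshold, else use the fill value.
    expected = torch.where(x > 1, x, torch.tensor(2, dtype=torch.int32))
    assert torch.equal(out, expected)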
@@ -99,4 +99,12 @@ TOSA_PASS_SET = {
     "LayerNormNormalizeOverAllDimsModule_basic",
     "PermuteModule_basic",
     "PermuteNegativeIndexModule_basic",
+    "ElementwiseLog2Module_basic",
+    "Threshold1dIntI32Module_basic",
+    "Threshold1dFloatModule_basic",
+    "Threshold2dFloatModule_basic",
+    "Threshold3dFloatModule_basic",
+    "ElementwiseSubScalarIntModule_basic",
+    "ElementwiseAddScalarIntModule_basic",
+    "ElementwiseMulScalarModule_basic",
 }
@@ -111,38 +111,63 @@ public:
   }
 };
 
+template <typename T>
+static bool isInValidRange(bool isFloat, const double &doubleValue, bool isInt,
+                           const int64_t &intValue) {
+  if (isFloat) {
+    return (doubleValue >= std::numeric_limits<T>::min()) &&
+           (doubleValue <= std::numeric_limits<T>::max());
+  } else {
+    assert(isInt);
+    return (intValue >= std::numeric_limits<T>::min()) &&
+           (intValue <= std::numeric_limits<T>::max());
+  }
+  return true;
+}
+
 // FIXME: This will eventually go into a Tosa*Utils file.
 LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
                                       Operation *op, Value torchScalarValue,
-                                      Value &tosaTensor, Type dtype) {
+                                      Value &tosaTensor, Type dtype,
+                                      llvm::ArrayRef<int64_t> dshape) {
+  // Retrieve a const float or int value but create the out Tensor with dtype.
+  double doubleValue;
+  auto isFloat =
+      matchPattern(torchScalarValue, m_TorchConstantFloat(&doubleValue));
+
+  int64_t intValue;
+  auto isInt = matchPattern(torchScalarValue, m_TorchConstantInt(&intValue));
+
+  if (!isFloat && !isInt)
+    return op->emitError("Unable to extract the scalar constant");
+
   if (dtype.isa<mlir::FloatType>()) {
-    double scalarValue;
-
-    if (!matchPattern(torchScalarValue, m_TorchConstantFloat(&scalarValue)))
-      return failure();
-
-    tosaTensor =
-        mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, scalarValue);
+    tosaTensor = tosa::getConstTensor<float>(
+                     rewriter, op, (isFloat ? doubleValue : intValue), dshape)
+                     .getValue();
   } else if (auto intType = dtype.dyn_cast<mlir::IntegerType>()) {
-    int64_t scalarValue;
-
-    if (!matchPattern(torchScalarValue, m_TorchConstantInt(&scalarValue)))
-      return failure();
-
     auto w = intType.getWidth();
     if (w != 32 && w != 64)
       return op->emitError("Unsupported integer type") << intType;
 
     if (w == 32) {
-      tosaTensor = tosa::getConstTensor<int32_t>(
-                       rewriter, op, {static_cast<int32_t>(scalarValue)}, {})
-                       .getValue();
-    } else if (w == 64) {
+      if (!isInValidRange<int32_t>(isFloat, doubleValue, isInt, intValue)) {
+        return op->emitError("Supplied value of scalar constant exceeds limits "
+                             "of destination type");
+      }
+      int32_t d = isFloat ? static_cast<int32_t>(doubleValue)
+                          : static_cast<int32_t>(intValue);
       tosaTensor =
-          tosa::getConstTensor<int64_t>(rewriter, op, {scalarValue}, {})
-              .getValue();
+          tosa::getConstTensor<int32_t>(rewriter, op, {d}, dshape).getValue();
+    } else if (w == 64) {
+      if (!isInValidRange<int64_t>(isFloat, doubleValue, isInt, intValue)) {
+        return op->emitError("Supplied value of scalar constant exceeds limits "
+                             "of destination type");
+      }
+      int64_t d = (isFloat ? static_cast<int64_t>(doubleValue) : intValue);
+      tosaTensor =
+          tosa::getConstTensor<int64_t>(rewriter, op, {d}, dshape).getValue();
     }
-    return success();
   } else
     return op->emitError("Usupported element type");
 
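Note: isInValidRange guards the new dshape-aware constant path - the scalar is rejected before a tosa.const is materialized if it does not fit the destination integer width. A rough sketch of the integer-path bound check in Python/NumPy (illustration only; numpy is an assumption here, not part of the patch):

    import numpy as np

    def is_in_valid_range(value, dst_dtype=np.int32):
        # Mirrors the std::numeric_limits<T> min/max check in isInValidRange.
        info = np.iinfo(dst_dtype)
        return info.min <= int(value) <= info.max

    assert is_in_valid_range(3, np.int32)
    assert not is_in_valid_range(2**40, np.int32)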
@@ -154,7 +179,7 @@ LogicalResult torchAlphaToTosaTensor(ConversionPatternRewriter &rewriter,
                                      Value &alphaTensor, Type dtype,
                                      bool checkForUnity) {
   if (succeeded(torchScalarToTosaTensor(rewriter, op, alphaScalar, alphaTensor,
-                                        dtype)))
+                                        dtype, {})))
     return success();
 
   // `alpha` has not been specified.
@@ -183,44 +208,59 @@ public:
   matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Value lhs = adaptor.self();
-    auto lhsTy = lhs.getType().dyn_cast<TensorType>();
+    auto lhsType = lhs.getType().dyn_cast<TensorType>();
     Value rhs = adaptor.other();
-    auto rhsTy = rhs.getType().dyn_cast<TensorType>();
+    auto rhsType = rhs.getType().dyn_cast<TensorType>();
 
-    if (!lhsTy)
+    if (!lhsType)
       return op.emitError("Only Tensor types supported in TOSA");
 
-    auto lhsElemTy = lhsTy.getElementType();
-    if (!lhsElemTy.isIntOrFloat())
+    if (auto lhsElemTy = lhsType.getElementType().dyn_cast<IntegerType>()) {
+      if (lhsElemTy.getWidth() > 32)
+        return op.emitError(
+            "Integers with widths greater than 32 are not supported");
+    }
+
+    auto outType =
+        static_cast<Type>(
+            OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+                op.getType()))
+            .cast<TensorType>();
+
+    Type outElemTy = outType.getElementType();
+    if (!outElemTy.isIntOrFloat()) {
       return op.emitError(
           "Only floating-point or integer datatype legalization supported");
+    }
 
     Value rhsAsTensor;
-    if (!rhsTy) {
-      if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(),
-                                         op.other(), rhsAsTensor, lhsElemTy)))
+    if (!rhsType) {
+      if (failed(torchScalarToTosaTensor(rewriter, op, op.other(), rhsAsTensor,
+                                         outElemTy, {})))
         return op.emitError("Currently only scalar constants are supported for "
                             "conversion in TOSA operation");
     }
-    auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
+    auto rhsTensor = rhsType ? rhs : rhsAsTensor;
 
     // Handle alpha.
     Value alphaTensor;
     if (failed(torchAlphaToTosaTensor(rewriter, op.getOperation(), op.alpha(),
-                                      alphaTensor, lhsElemTy, false)))
+                                      alphaTensor, outElemTy,
+                                      /*checkForUnity=*/false))) {
       return op.emitError("Currently only scalar constants are supported for "
                           "alpha in conversion to TOSA operation");
+    }
 
     auto multTensor = rewriter.create<tosa::MulOp>(
-        op.getLoc(), rhsTy ? rhsTy : RankedTensorType::get({}, lhsElemTy),
-        rhsTensor, alphaTensor, /*shift*/ 0);
+        op.getLoc(), rhsType ? rhsType : RankedTensorType::get({}, outElemTy),
+        rhsTensor, alphaTensor, /*shift=*/0);
 
-    if (lhsElemTy.isa<mlir::FloatType>()) {
-      rewriter.replaceOpWithNewOp<TosaOpT>(
-          op,
-          OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
-              op.getType()),
-          lhs, multTensor);
+    if (outElemTy.isa<mlir::FloatType>()) {
+      if (lhsType.getElementType() != outElemTy)
+        lhs = rewriter.create<tosa::CastOp>(op.getLoc(), outType, lhs);
+
+      rewriter.replaceOpWithNewOp<TosaOpT>(op, outType, lhs, multTensor);
+
       return success();
     } else {
       return op.emitError(
|
||||||
return op.emitError(
|
return op.emitError(
|
||||||
"Only floating-point or integer datatype legalization supported");
|
"Only floating-point or integer datatype legalization supported");
|
||||||
|
|
||||||
|
// For bitwise operators, only integer datatype legalization is supported
|
||||||
|
if (lhsElemTy.isa<mlir::FloatType>() &&
|
||||||
|
std::is_same<AtenOpT, AtenBitwiseAndTensorOp>()) {
|
||||||
|
return op.emitError("For bitwise operators, only integer datatype "
|
||||||
|
"legalization is supported");
|
||||||
|
}
|
||||||
|
|
||||||
Value rhsAsTensor;
|
Value rhsAsTensor;
|
||||||
if (!rhsTy) {
|
if (!rhsTy) {
|
||||||
if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(),
|
if (failed(torchScalarToTosaTensor(rewriter, op, op.other(), rhsAsTensor,
|
||||||
op.other(), rhsAsTensor, lhsElemTy)))
|
lhsElemTy, {})))
|
||||||
return op.emitError("Currently only scalar constants are supported for "
|
return op.emitError("Currently only scalar constants are supported for "
|
||||||
"conversion in TOSA operation");
|
"conversion in TOSA operation");
|
||||||
}
|
}
|
||||||
auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
|
auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
|
||||||
// There is no Lesser operator in TOSA
|
// There is no Lesser operator in TOSA.
|
||||||
auto swapLhsRhs = (std::is_same<AtenOpT, AtenLtTensorOp>() ||
|
auto swapLhsRhs = (std::is_same<AtenOpT, AtenLtTensorOp>() ||
|
||||||
std::is_same<AtenOpT, AtenLtScalarOp>());
|
std::is_same<AtenOpT, AtenLtScalarOp>());
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<TosaOpT>(
|
auto resultOp = rewriter.create<TosaOpT>(
|
||||||
op,
|
op.getLoc(),
|
||||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||||
op.getType()),
|
op.getType()),
|
||||||
(swapLhsRhs ? rhsTensor : lhs), (swapLhsRhs ? lhs : rhsTensor));
|
(swapLhsRhs ? rhsTensor : lhs), (swapLhsRhs ? lhs : rhsTensor));
|
||||||
|
|
||||||
|
// There is no NE operator in TOSA.
|
||||||
|
if (std::is_same<AtenOpT, AtenNeTensorOp>() ||
|
||||||
|
std::is_same<AtenOpT, AtenNeScalarOp>())
|
||||||
|
rewriter.replaceOpWithNewOp<tosa::LogicalNotOp>(
|
||||||
|
op,
|
||||||
|
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||||
|
op.getType()),
|
||||||
|
resultOp.getResult());
|
||||||
|
else
|
||||||
|
rewriter.replaceOp(op, resultOp.getResult());
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
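Note: TOSA has no not-equal (or less-than) comparison, so AtenNe*Op is lowered as tosa.equal followed by tosa.logical_not, while the existing operand swap keeps reusing tosa.greater for lt. The same identities in PyTorch (sketch):

    import torch

    a = torch.rand(2, 3)
    b = torch.rand(2, 3)

    # ne lowered as logical_not(equal(a, b))
    assert torch.equal(torch.ne(a, b), torch.logical_not(torch.eq(a, b)))
    # lt lowered by swapping the operands of a greater comparison
    assert torch.equal(torch.lt(a, b), torch.gt(b, a))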
@@ -282,29 +341,44 @@ public:
   matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Value lhs = adaptor.self();
-    auto lhsTy = lhs.getType().dyn_cast<TensorType>();
-    Value rhs = adaptor.other();
-    auto rhsTy = rhs.getType().dyn_cast<TensorType>();
+    auto lhsType = lhs.getType().dyn_cast<TensorType>();
 
-    if (!lhsTy)
+    if (!lhsType)
       return op.emitError("Only Tensor types supported in TOSA");
 
-    auto lhsElemTy = lhsTy.getElementType();
-    if (!lhsElemTy.isIntOrFloat())
+    auto outType =
+        static_cast<Type>(
+            OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+                op.getType()))
+            .cast<TensorType>();
+
+    Type outElemTy = outType.getElementType();
+    if (!outElemTy.isIntOrFloat())
       return op.emitError(
           "Only floating-point or integer datatype legalization supported");
 
-    Value rhsAsTensor;
-    if (!rhsTy) {
-      if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(),
-                                         op.other(), rhsAsTensor, lhsElemTy)))
-        return op.emitError("Currently only scalar constants are supported for "
-                            "conversion in TOSA operation");
+    Value rhsTensor;
+    if (std::is_same<AtenOpT, AtenSquareOp>()) {
+      rhsTensor = lhs;
+    } else {
+      Value rhsAsTensor;
+      Value rhs = adaptor.other();
+      auto rhsType = rhs.getType().dyn_cast<TensorType>();
+      if (!rhsType) {
+        if (failed(torchScalarToTosaTensor(rewriter, op, op.other(),
+                                           rhsAsTensor, outElemTy, {})))
+          return op.emitError(
+              "Currently only scalar constants are supported for "
+              "conversion in TOSA operation");
+      }
+      rhsTensor = rhsType ? rhs : rhsAsTensor;
     }
-    auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
 
-    if (lhsElemTy.isa<mlir::FloatType>() ||
-        lhsElemTy.isa<mlir::IntegerType>()) {
+    if (outElemTy.isa<mlir::FloatType>() ||
+        outElemTy.isa<mlir::IntegerType>()) {
+      if (lhsType.getElementType() != outElemTy)
+        lhs = rewriter.create<tosa::CastOp>(op.getLoc(), outType, lhs);
+
       rewriter.replaceOpWithNewOp<tosa::MulOp>(
           op,
           OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
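Note: AtenSquareOp is folded into the same multiply pattern by reusing self as the right-hand operand (rhsTensor = lhs), i.e. square(x) = x * x. Sketch:

    import torch

    x = torch.rand(3, 4)
    # The lowering reuses self as both multiplicands of tosa.mul.
    assert torch.allclose(torch.square(x), torch.mul(x, x))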
@@ -343,8 +417,8 @@ public:
 
     Value rhsAsTensor;
     if (!rhsTy) {
-      if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(),
-                                         op.other(), rhsAsTensor, lhsElemTy)))
+      if (failed(torchScalarToTosaTensor(rewriter, op, op.other(), rhsAsTensor,
+                                         lhsElemTy, {})))
         return op.emitError("Currently only scalar constants are supported for "
                             "conversion in TOSA operation");
     }
@@ -807,8 +881,8 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
 
   Value expTensor;
   Value expScalar = op.exponent();
-  if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(), expScalar,
-                                     expTensor, selfTy.getElementType())))
+  if (failed(torchScalarToTosaTensor(rewriter, op, expScalar, expTensor,
+                                     selfTy.getElementType(), {})))
     return op.emitError("Currently only scalar constants are supported for "
                         "conversion in TOSA Pow operation");
 
|
@ -1584,19 +1658,19 @@ LogicalResult ConvertAtenOp<AtenRsubScalarOp>::matchAndRewrite(
|
||||||
|
|
||||||
Value otherTensor, alphaTensor;
|
Value otherTensor, alphaTensor;
|
||||||
|
|
||||||
if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(), otherScalar,
|
if (failed(torchScalarToTosaTensor(rewriter, op, otherScalar, otherTensor,
|
||||||
otherTensor, selfTy.getElementType())))
|
selfTy.getElementType(), {})))
|
||||||
return op.emitError("Currently only scalar constants are supported for "
|
return op.emitError("Currently only scalar constants are supported for "
|
||||||
"conversion in TOSA Rsub operation");
|
"conversion in TOSA Rsub operation");
|
||||||
|
|
||||||
if (failed(torchAlphaToTosaTensor(rewriter, op.getOperation(), alphaScalar,
|
if (failed(torchAlphaToTosaTensor(rewriter, op.getOperation(), alphaScalar,
|
||||||
alphaTensor, selfTy.getElementType(),
|
alphaTensor, selfTy.getElementType(),
|
||||||
true)))
|
/*checkForUnity=*/true)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto multTensor = rewriter.create<tosa::MulOp>(
|
auto multTensor = rewriter.create<tosa::MulOp>(
|
||||||
op->getLoc(), getTypeConverter()->convertType(op.getType()), self,
|
op->getLoc(), getTypeConverter()->convertType(op.getType()), self,
|
||||||
alphaTensor, /*shift*/ 0);
|
alphaTensor, /*shift=*/0);
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<tosa::SubOp>(
|
rewriter.replaceOpWithNewOp<tosa::SubOp>(
|
||||||
op, getTypeConverter()->convertType(op.getType()), otherTensor,
|
op, getTypeConverter()->convertType(op.getType()), otherTensor,
|
||||||
|
@@ -2142,7 +2216,7 @@ LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite(
 
   auto newType = RankedTensorType::get(newShape, selfType.getElementType());
   auto reshapeOp =
-      rewriter.create<tosa::ReshapeOp>(op->getLoc(), newType, adaptor.self(),
+      rewriter.create<tosa::ReshapeOp>(op.getLoc(), newType, adaptor.self(),
                                        rewriter.getI64ArrayAttr(newShape));
 
   rewriter.replaceOpWithNewOp<tensor::CastOp>(
@@ -2184,6 +2258,80 @@ LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
   return success();
 }
 
+template <>
+LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
+    AtenLog2Op op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  // Not a tensor type.
+  auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
+  if (!selfType)
+    return op.emitError("Only tensor types are currently supported");
+
+  // Constant value of ln2.
+  SmallVector<int64_t> ln2Shape(selfType.getRank(), 1);
+  auto ln2Op =
+      tosa::getConstTensor<float>(rewriter, op, {0.69314718056}, ln2Shape)
+          .getValue();
+  auto rcpOp =
+      rewriter.create<tosa::ReciprocalOp>(op.getLoc(), ln2Op.getType(), ln2Op);
+
+  auto outType = getTypeConverter()->convertType(op.getType());
+  auto logOp =
+      rewriter.create<tosa::LogOp>(op.getLoc(), outType, adaptor.self());
+  rewriter.replaceOpWithNewOp<tosa::MulOp>(op, outType, logOp, rcpOp,
+                                           /*shift=*/0);
+
+  return success();
+}
+
+template <>
+LogicalResult ConvertAtenOp<AtenThresholdOp>::matchAndRewrite(
+    AtenThresholdOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  // Not a tensor type.
+  auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
+  if (!selfType)
+    return op.emitError("Only tensor types are currently supported");
+
+  auto selfElemTy = selfType.getElementType();
+  if (!selfElemTy.isIntOrFloat())
+    return op.emitError(
+        "Only floating-point or integer datatype legalization supported");
+
+  // Integer types with width > 32 are not supported
+  auto selfIntType = selfElemTy.dyn_cast<IntegerType>();
+  if (selfIntType && selfIntType.getWidth() > 32) {
+    return op.emitError(
+        "Integer types with width greater than 32 are not supported");
+  }
+
+  SmallVector<int64_t> constTypeShape(selfType.getRank(), 1);
+  Value threshold, value;
+  if (failed(torchScalarToTosaTensor(rewriter, op, op.threshold(), threshold,
+                                     selfElemTy, constTypeShape)))
+    return op.emitError("Only scalar constant is supported for threshold");
+
+  if (failed(torchScalarToTosaTensor(rewriter, op, op.value(), value,
+                                     selfElemTy, constTypeShape)))
+    return op.emitError("Only scalar constant is supported for value");
+
+  // Threshold only clamps the upper values. tosa::ClampOp has the same
+  // value for both threshold and clamped value so cannot be used.
+  auto outType = getTypeConverter()->convertType(op.getType());
+
+  auto cmpOp = rewriter.create<tosa::GreaterOp>(
+      op.getLoc(),
+      RankedTensorType::get(selfType.getShape(), rewriter.getIntegerType(1)),
+      adaptor.self(), threshold);
+
+  rewriter.replaceOpWithNewOp<tosa::SelectOp>(op, outType, cmpOp,
+                                              adaptor.self(), value);
+
+  return success();
+}
+
 template <typename AtenOpT, typename TosaOpT>
 class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
 public:
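Note: TOSA has no log2, so the AtenLog2Op lowering multiplies tosa.log by the reciprocal of the ln(2) constant materialized above (0.69314718056), i.e. log2(x) = log(x) / ln(2). A quick numerical check of that identity (sketch):

    import math
    import torch

    x = torch.rand(3, 4) + 0.1  # keep inputs strictly positive
    ln2 = 0.69314718056         # the constant materialized by the lowering
    assert torch.allclose(torch.log2(x), torch.log(x) * (1.0 / ln2))
    assert math.isclose(ln2, math.log(2.0), rel_tol=1e-9)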
@@ -2527,6 +2675,9 @@ public:
     INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp, tosa::EqualOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp, tosa::EqualOp)
+    INSERT_BINARY_COMPARE_PATTERN(AtenNeTensorOp, tosa::EqualOp)
+    INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp, tosa::EqualOp)
+    INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndTensorOp, tosa::BitwiseAndOp)
 #undef INSERT_BINARY_COMPARE_PATTERN
 
 #define INSERT_BINARY_MUL_PATTERN(AtenOp) \
@@ -2629,6 +2780,8 @@ public:
     INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
     INSERT_ATENOP_PATTERN(AtenFlattenUsingIntsOp);
     INSERT_ATENOP_PATTERN(AtenPermuteOp);
+    INSERT_ATENOP_PATTERN(AtenLog2Op);
+    INSERT_ATENOP_PATTERN(AtenThresholdOp);
 #undef INSERT_ATENOP_PATTERN
 
     if (failed(applyPartialConversion(getOperation(), target,
@@ -595,6 +595,23 @@ func @forward(%arg0: !torch.vtensor<[5,2,2,3],f32> , %arg1: !torch.vtensor<[2,2,
 
 // -----
 
+// CHECK-LABEL: func @torch.aten.ne.Tensor$basic(
+// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
+// CHECK-SAME:    %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
+// CHECK:         %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[VAL_4:.*]] = "tosa.equal"(%[[VAL_2]], %[[VAL_3]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
+// CHECK:         %[[VAL_5:.*]] = "tosa.logical_not"(%[[VAL_4]]) : (tensor<?x?xi1>) -> tensor<?x?xi1>
+// CHECK:         %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK:         return %[[VAL_6]] : !torch.vtensor<[?,?],i1>
+// CHECK:         }
+func @torch.aten.ne.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
+  %0 = torch.aten.ne.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}
+
+// -----
+
 // CHECK-LABEL: func @forward(
 // CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,2,4],f32> {
 // CHECK:         %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,2],f32> -> tensor<3x4x2xf32>
@@ -615,3 +632,36 @@ func @forward(%arg0: !torch.vtensor<[3,4,2],f32> ) -> !torch.vtensor<[3,2,4],f32
   %1 = torch.aten.permute %arg0, %0 : !torch.vtensor<[3,4,2],f32>, !torch.list<!torch.int> -> !torch.vtensor<[3,2,4],f32>
   return %1 : !torch.vtensor<[3,2,4],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func @torch.aten.bitwise_and.Tensor$basic(
+// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>,
+// CHECK-SAME:    %[[VAL_1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
+// CHECK:         %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
+// CHECK:         %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
+// CHECK:         %[[VAL_4:.*]] = "tosa.bitwise_and"(%[[VAL_2]], %[[VAL_3]]) : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
+// CHECK:         %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi32> -> !torch.vtensor<[?,?],si32>
+// CHECK:         return %[[VAL_5]] : !torch.vtensor<[?,?],si32>
+// CHECK:         }
+func @torch.aten.bitwise_and.Tensor$basic(%arg0: !torch.vtensor<[?,?],si32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
+  %0 = torch.aten.bitwise_and.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32>
+  return %0 : !torch.vtensor<[?,?],si32>
+}
+
+// -----
+
+// CHECK-LABEL: func @torch.aten.log2$basic(
+// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK:         %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[VAL_2:.*]] = "tosa.const"() {value = dense<0.693147182> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
+// CHECK:         %[[VAL_3:.*]] = "tosa.reciprocal"(%[[VAL_2]]) : (tensor<1x1xf32>) -> tensor<1x1xf32>
+// CHECK:         %[[VAL_4:.*]] = "tosa.log"(%[[VAL_1]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:         %[[VAL_5:.*]] = "tosa.mul"(%[[VAL_4]], %[[VAL_3]]) {shift = 0 : i32} : (tensor<?x?xf32>, tensor<1x1xf32>) -> tensor<?x?xf32>
+// CHECK:         %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:         return %[[VAL_6]] : !torch.vtensor<[?,?],f32>
+// CHECK:         }
+func @torch.aten.log2$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.log2 %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}