mirror of https://github.com/llvm/torch-mlir
* [tosa] Support for AtenRsubScalarOp for scalar constants (#531)
* [tosa] Support for AtenCeilOp and AtenReciprocalOp * [tosa] Support for comparator ops, Aten[Gt|Lt|Eq][Tensor|Scalar]Op with scalar constant * [tosa] Support for Scalar variants of Aten[Mul|Div|Add|Sub] Ops with scalar constants Signed-off-by: Anup Gangwar <anup.gangwar@arm.com> Co-authored-by: Anup Gangwar <anup.gangwar@arm.com>pull/536/head snapshot-20220120.218
parent
c8ee8d02eb
commit
f8080bd1c5
|
@ -54,4 +54,31 @@ TOSA_PASS_SET = {
|
|||
"BmmModule_basic",
|
||||
"Matmul_dot",
|
||||
"Matmul_3d",
|
||||
"RsubModule_basic",
|
||||
"RsubModule_noalpha_basic",
|
||||
"ElementwiseGtFloatScalarModule_basic",
|
||||
"ElementwiseGtIntScalarModule_basic",
|
||||
"ElementwiseGtMixed2ScalarModule_basic",
|
||||
"ElementwiseGtFloatTensorModule_basic",
|
||||
"ElementwiseGtIntTensorModule_basic",
|
||||
"ElementwiseLtFloatScalarModule_basic",
|
||||
"ElementwiseLtIntScalarModule_basic",
|
||||
"ElementwiseLtDiffWidthScalarModule_basic",
|
||||
"ElementwiseLtFloatTensorModule_basic",
|
||||
"ElementwiseLtIntTensorModule_basic",
|
||||
"ElementwiseEqFloatScalarModule_basic",
|
||||
"ElementwiseEqIntScalarModule_basic",
|
||||
"ElementwiseEqDiffWidthScalarModule_basic",
|
||||
"ElementwiseEqFloatTensorModule_basic",
|
||||
"ElementwiseEqIntTensorModule_basic",
|
||||
"ElementwiseMulScalarModule_int",
|
||||
"ElementwiseMulScalarModule_float",
|
||||
"ElementwiseMulTensorIntModule_basic",
|
||||
"ElementwiseDivScalarModule_basic",
|
||||
"ElementwiseSubScalarFloatModule_basic",
|
||||
"ElementwiseAddScalarFloatModule_basic",
|
||||
"ElementwiseMulScalarModule_float",
|
||||
"ElementwiseCeilModule_basic",
|
||||
"ElementwiseReciprocalModule_basic",
|
||||
"TypePromotionAlphaWiderModule_basic",
|
||||
}
|
||||
|
|
|
@ -110,6 +110,67 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
// FIXME: This will eventually go into a Tosa*Utils file.
|
||||
LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
|
||||
Operation *op, Value torchScalarValue,
|
||||
Value &tosaTensor, Type dtype) {
|
||||
if (dtype.isa<mlir::FloatType>()) {
|
||||
double scalarValue;
|
||||
|
||||
if (!matchPattern(torchScalarValue, m_TorchConstantFloat(&scalarValue)))
|
||||
return failure();
|
||||
|
||||
tosaTensor =
|
||||
mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, scalarValue);
|
||||
} else if (auto intType = dtype.dyn_cast<mlir::IntegerType>()) {
|
||||
int64_t scalarValue;
|
||||
|
||||
if (!matchPattern(torchScalarValue, m_TorchConstantInt(&scalarValue)))
|
||||
return failure();
|
||||
|
||||
auto w = intType.getWidth();
|
||||
if (w != 32 && w != 64)
|
||||
return op->emitError("Unsupported integer type") << intType;
|
||||
|
||||
if (w == 32) {
|
||||
tosaTensor = tosa::getConstTensor<int32_t>(
|
||||
rewriter, op, {static_cast<int32_t>(scalarValue)}, {})
|
||||
.getValue();
|
||||
} else if (w == 64) {
|
||||
tosaTensor =
|
||||
tosa::getConstTensor<int64_t>(rewriter, op, {scalarValue}, {})
|
||||
.getValue();
|
||||
}
|
||||
return success();
|
||||
} else
|
||||
return op->emitError("Usupported element type");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult torchAlphaToTosaTensor(ConversionPatternRewriter &rewriter,
|
||||
Operation *op, Value alphaScalar,
|
||||
Value &alphaTensor, Type dtype,
|
||||
bool checkForUnity) {
|
||||
if (succeeded(torchScalarToTosaTensor(rewriter, op, alphaScalar, alphaTensor,
|
||||
dtype)))
|
||||
return success();
|
||||
|
||||
// `alpha` has not been specified.
|
||||
int64_t alphaValue;
|
||||
if (!matchPattern(alphaScalar, m_TorchConstantInt(&alphaValue)))
|
||||
return op->emitError("Currently only scalar constants are supported for "
|
||||
"alpha in TOSA operation");
|
||||
// When no alpha has been specified, this must be 1.
|
||||
if (checkForUnity && alphaValue != 1)
|
||||
return op->emitError("Unsupported integer value for alpha");
|
||||
|
||||
alphaTensor =
|
||||
mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, alphaValue);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
// These binary op legalizations are specific to add/sub which have an
|
||||
// alpha multiplier.
|
||||
template <typename AtenOpT, typename TosaOpT>
|
||||
|
@ -121,34 +182,191 @@ public:
|
|||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value lhs = adaptor.self();
|
||||
auto lhsTy = lhs.getType().cast<TensorType>();
|
||||
auto lhsTy = lhs.getType().dyn_cast<TensorType>();
|
||||
Value rhs = adaptor.other();
|
||||
auto rhsTy = rhs.getType().cast<TensorType>();
|
||||
auto rhsTy = rhs.getType().dyn_cast<TensorType>();
|
||||
|
||||
if (!lhsTy || !rhsTy)
|
||||
if (!lhsTy)
|
||||
return op.emitError("Only Tensor types supported in TOSA");
|
||||
|
||||
auto lhsElemTy = lhsTy.getElementType();
|
||||
auto rhsElemTy = rhsTy.getElementType();
|
||||
if (!lhsElemTy.isIntOrFloat())
|
||||
return op.emitError(
|
||||
"Only floating-point or integer datatype legalization supported");
|
||||
|
||||
if (lhsElemTy != rhsElemTy)
|
||||
return op.emitError("Add: input datatypes mismatched");
|
||||
Value rhsAsTensor;
|
||||
if (!rhsTy) {
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(),
|
||||
op.other(), rhsAsTensor, lhsElemTy)))
|
||||
return op.emitError("Currently only scalar constants are supported for "
|
||||
"conversion in TOSA operation");
|
||||
}
|
||||
auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
|
||||
|
||||
// FIXME: Handle alpha.
|
||||
// Needs extraction of floating point constant.
|
||||
// Handle alpha.
|
||||
Value alphaTensor;
|
||||
if (failed(torchAlphaToTosaTensor(rewriter, op.getOperation(), op.alpha(),
|
||||
alphaTensor, lhsElemTy, false)))
|
||||
return op.emitError("Currently only scalar constants are supported for "
|
||||
"alpha in conversion to TOSA operation");
|
||||
|
||||
auto multTensor = rewriter.create<tosa::MulOp>(
|
||||
op.getLoc(), rhsTy ? rhsTy : RankedTensorType::get({}, lhsElemTy),
|
||||
rhsTensor, alphaTensor, /*shift*/ 0);
|
||||
|
||||
if (lhsElemTy.isa<mlir::FloatType>()) {
|
||||
rewriter.replaceOpWithNewOp<TosaOpT>(
|
||||
op,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()),
|
||||
lhs, rhs);
|
||||
lhs, multTensor);
|
||||
return success();
|
||||
} else {
|
||||
return op.emitError(
|
||||
"Only floating-point datatype legalization supported");
|
||||
}
|
||||
}
|
||||
}; // namespace
|
||||
|
||||
// Binary op legalizations for comparator ops.
|
||||
template <typename AtenOpT, typename TosaOpT>
|
||||
class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value lhs = adaptor.self();
|
||||
auto lhsTy = lhs.getType().dyn_cast<TensorType>();
|
||||
Value rhs = adaptor.other();
|
||||
auto rhsTy = rhs.getType().dyn_cast<TensorType>();
|
||||
|
||||
if (!lhsTy)
|
||||
return op.emitError("Only Tensor types supported in TOSA");
|
||||
|
||||
auto lhsElemTy = lhsTy.getElementType();
|
||||
if (!lhsElemTy.isIntOrFloat())
|
||||
return op.emitError(
|
||||
"Only floating-point or integer datatype legalization supported");
|
||||
|
||||
Value rhsAsTensor;
|
||||
if (!rhsTy) {
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(),
|
||||
op.other(), rhsAsTensor, lhsElemTy)))
|
||||
return op.emitError("Currently only scalar constants are supported for "
|
||||
"conversion in TOSA operation");
|
||||
}
|
||||
auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
|
||||
// There is no Lesser operator in TOSA
|
||||
auto swapLhsRhs = (std::is_same<AtenOpT, AtenLtTensorOp>() ||
|
||||
std::is_same<AtenOpT, AtenLtScalarOp>());
|
||||
|
||||
rewriter.replaceOpWithNewOp<TosaOpT>(
|
||||
op,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()),
|
||||
(swapLhsRhs ? rhsTensor : lhs), (swapLhsRhs ? lhs : rhsTensor));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Binary op legalizations for Mul variants.
|
||||
template <typename AtenOpT>
|
||||
class ConvertAtenMulOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value lhs = adaptor.self();
|
||||
auto lhsTy = lhs.getType().dyn_cast<TensorType>();
|
||||
Value rhs = adaptor.other();
|
||||
auto rhsTy = rhs.getType().dyn_cast<TensorType>();
|
||||
|
||||
if (!lhsTy)
|
||||
return op.emitError("Only Tensor types supported in TOSA");
|
||||
|
||||
auto lhsElemTy = lhsTy.getElementType();
|
||||
if (!lhsElemTy.isIntOrFloat())
|
||||
return op.emitError(
|
||||
"Only floating-point or integer datatype legalization supported");
|
||||
|
||||
Value rhsAsTensor;
|
||||
if (!rhsTy) {
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(),
|
||||
op.other(), rhsAsTensor, lhsElemTy)))
|
||||
return op.emitError("Currently only scalar constants are supported for "
|
||||
"conversion in TOSA operation");
|
||||
}
|
||||
auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
|
||||
|
||||
if (lhsElemTy.isa<mlir::FloatType>() ||
|
||||
lhsElemTy.isa<mlir::IntegerType>()) {
|
||||
rewriter.replaceOpWithNewOp<tosa::MulOp>(
|
||||
op,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()),
|
||||
lhs, rhsTensor,
|
||||
/*shift=*/0);
|
||||
return success();
|
||||
} else {
|
||||
// Quantized multiplication may need to rescale inputs.
|
||||
return op.emitError("Only floating-point or integer datatype "
|
||||
"legalization currently supported");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename AtenOpT>
|
||||
class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
using OpConversionPattern<AtenOpT>::OpConversionPattern;
|
||||
using OpAdaptor = typename AtenOpT::Adaptor;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value lhs = adaptor.self();
|
||||
auto lhsTy = lhs.getType().dyn_cast<TensorType>();
|
||||
Value rhs = adaptor.other();
|
||||
auto rhsTy = rhs.getType().dyn_cast<TensorType>();
|
||||
|
||||
if (!lhsTy)
|
||||
return op.emitError("Only Tensor types supported in TOSA");
|
||||
|
||||
auto lhsElemTy = lhsTy.getElementType();
|
||||
if (!lhsElemTy.isIntOrFloat())
|
||||
return op.emitError(
|
||||
"Only floating-point or integer datatype legalization supported");
|
||||
|
||||
Value rhsAsTensor;
|
||||
if (!rhsTy) {
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(),
|
||||
op.other(), rhsAsTensor, lhsElemTy)))
|
||||
return op.emitError("Currently only scalar constants are supported for "
|
||||
"conversion in TOSA operation");
|
||||
}
|
||||
auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
|
||||
|
||||
if (lhsElemTy.isa<mlir::FloatType>()) {
|
||||
auto rcpOp = rewriter.create<tosa::ReciprocalOp>(
|
||||
op->getLoc(), rhsTy ? rhsTy : RankedTensorType::get({}, lhsElemTy),
|
||||
rhsTensor);
|
||||
rewriter.replaceOpWithNewOp<tosa::MulOp>(
|
||||
op,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()),
|
||||
lhs, rcpOp.getResult(), /*shift=*/0);
|
||||
} else {
|
||||
rewriter.replaceOpWithNewOp<tosa::DivOp>(
|
||||
op,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()),
|
||||
lhs, rhsTensor);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// This defines a template to construct ops whose legalizations are
|
||||
|
@ -227,69 +445,6 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
|
|||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenMulTensorOp>::matchAndRewrite(
|
||||
AtenMulTensorOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
Value lhs = adaptor.self();
|
||||
auto lhsTy = lhs.getType().cast<TensorType>();
|
||||
Value rhs = adaptor.other();
|
||||
auto rhsTy = rhs.getType().cast<TensorType>();
|
||||
|
||||
if (!lhsTy || !rhsTy)
|
||||
return op.emitError("Only Tensor types supported in TOSA");
|
||||
|
||||
auto lhsElemTy = lhsTy.getElementType();
|
||||
auto rhsElemTy = rhsTy.getElementType();
|
||||
|
||||
if (lhsElemTy != rhsElemTy)
|
||||
return op.emitError("Add: input datatypes mismatched");
|
||||
|
||||
if (lhsElemTy.isa<mlir::FloatType>()) {
|
||||
rewriter.replaceOpWithNewOp<tosa::MulOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), lhs, rhs,
|
||||
/*shift=*/0);
|
||||
return success();
|
||||
} else {
|
||||
// Quantized multiplication may need to rescale inputs.
|
||||
return op.emitError(
|
||||
"Only floating-point datatype legalization currently supported");
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenDivTensorOp>::matchAndRewrite(
|
||||
AtenDivTensorOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
Value lhs = adaptor.self();
|
||||
auto lhsTy = lhs.getType().cast<TensorType>();
|
||||
Value rhs = adaptor.other();
|
||||
auto rhsTy = rhs.getType().cast<TensorType>();
|
||||
|
||||
if (!lhsTy || !rhsTy)
|
||||
return op.emitError("Only Tensor types supported in TOSA");
|
||||
|
||||
auto lhsElemTy = lhsTy.getElementType();
|
||||
auto rhsElemTy = rhsTy.getElementType();
|
||||
|
||||
if (lhsElemTy != rhsElemTy)
|
||||
return op.emitError("Add: input datatypes mismatched");
|
||||
|
||||
if (lhsElemTy.isa<mlir::FloatType>()) {
|
||||
auto rcpOp = rewriter.create<tosa::ReciprocalOp>(
|
||||
op->getLoc(), getTypeConverter()->convertType(op.getType()), rhs);
|
||||
rewriter.replaceOpWithNewOp<tosa::MulOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), lhs,
|
||||
rcpOp.getResult(), /*shift=*/0);
|
||||
} else {
|
||||
rewriter.replaceOpWithNewOp<tosa::DivOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), lhs, rhs);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
using ReductionConvFunc = llvm::Optional<Value> (*)(PatternRewriter &,
|
||||
Operation *,
|
||||
RankedTensorType, Value,
|
||||
|
@ -635,23 +790,6 @@ class ConvertAtenSqueezeAllDimsOp : public ConvertAtenSqueezeOp<AtenOpT> {
|
|||
}
|
||||
};
|
||||
|
||||
// FIXME(AG): This will eventually go into a Tosa*Utils file
|
||||
// Convert an fp32 scalar into tosa fp32 tensor.
|
||||
static LogicalResult
|
||||
tosaF32TensorFromTorchFloat(ConversionPatternRewriter &rewriter, Operation *op,
|
||||
Value torchScalarValue, Value &tosaTensor) {
|
||||
double scalarValue;
|
||||
|
||||
if (!matchPattern(torchScalarValue, m_TorchConstantFloat(&scalarValue)))
|
||||
return failure();
|
||||
|
||||
// Construct a tosa.const
|
||||
tosaTensor =
|
||||
mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, scalarValue);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
|
||||
AtenPowTensorScalarOp op, OpAdaptor adaptor,
|
||||
|
@ -668,8 +806,8 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
|
|||
|
||||
Value expTensor;
|
||||
Value expScalar = op.exponent();
|
||||
if (failed(tosaF32TensorFromTorchFloat(rewriter, op.getOperation(), expScalar,
|
||||
expTensor)))
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(), expScalar,
|
||||
expTensor, selfTy.getElementType())))
|
||||
return op.emitError("Currently only scalar constants are supported for "
|
||||
"conversion in TOSA Pow operation");
|
||||
|
||||
|
@ -1238,6 +1376,45 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenRsubScalarOp>::matchAndRewrite(
|
||||
AtenRsubScalarOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
auto self = adaptor.self();
|
||||
auto otherScalar = op.other();
|
||||
auto alphaScalar = op.alpha();
|
||||
|
||||
auto selfTy = self.getType().template cast<RankedTensorType>();
|
||||
if (!selfTy)
|
||||
return op.emitError("Only ranked tensor types supported in TOSA Rsub");
|
||||
|
||||
if (!selfTy.getElementType().isa<mlir::FloatType>())
|
||||
return op.emitError("Only floating-point datatype legalization supported");
|
||||
|
||||
Value otherTensor, alphaTensor;
|
||||
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(), otherScalar,
|
||||
otherTensor, selfTy.getElementType())))
|
||||
return op.emitError("Currently only scalar constants are supported for "
|
||||
"conversion in TOSA Rsub operation");
|
||||
|
||||
if (failed(torchAlphaToTosaTensor(rewriter, op.getOperation(), alphaScalar,
|
||||
alphaTensor, selfTy.getElementType(),
|
||||
true)))
|
||||
return failure();
|
||||
|
||||
auto multTensor = rewriter.create<tosa::MulOp>(
|
||||
op->getLoc(), getTypeConverter()->convertType(op.getType()), self,
|
||||
alphaTensor, /*shift*/ 0);
|
||||
|
||||
rewriter.replaceOpWithNewOp<tosa::SubOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), otherTensor,
|
||||
multTensor);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
@ -1281,6 +1458,8 @@ public:
|
|||
INSERT_UNARY_PATTERN(AtenFloorOp, tosa::FloorOp)
|
||||
INSERT_UNARY_PATTERN(AtenRsqrtOp, tosa::RsqrtOp)
|
||||
INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp)
|
||||
INSERT_UNARY_PATTERN(AtenCeilOp, tosa::CeilOp)
|
||||
INSERT_UNARY_PATTERN(AtenReciprocalOp, tosa::ReciprocalOp)
|
||||
#undef INSERT_UNARY_PATTERN
|
||||
|
||||
#define INSERT_BINARY_PATTERN(AtenOp, TosaOp) \
|
||||
|
@ -1294,9 +1473,36 @@ public:
|
|||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenAddSubOp<AtenOp, TosaOp>>(typeConverter, context);
|
||||
INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, tosa::AddOp)
|
||||
INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, tosa::AddOp)
|
||||
INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, tosa::SubOp)
|
||||
INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, tosa::SubOp)
|
||||
#undef INSERT_BINARY_ADDSUB_PATTERN
|
||||
|
||||
#define INSERT_BINARY_COMPARE_PATTERN(AtenOp, TosaOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenCompareOp<AtenOp, TosaOp>>(typeConverter, context);
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp, tosa::GreaterOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp, tosa::EqualOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp, tosa::EqualOp)
|
||||
#undef INSERT_BINARY_COMPARE_PATTERN
|
||||
|
||||
#define INSERT_BINARY_MUL_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenMulOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_BINARY_MUL_PATTERN(AtenMulTensorOp);
|
||||
INSERT_BINARY_MUL_PATTERN(AtenMulScalarOp);
|
||||
#undef INSERT_BINARY_MUL_PATTERN
|
||||
|
||||
#define INSERT_BINARY_DIV_PATTERN(AtenOp) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenDivOp<AtenOp>>(typeConverter, context);
|
||||
INSERT_BINARY_DIV_PATTERN(AtenDivTensorOp);
|
||||
INSERT_BINARY_DIV_PATTERN(AtenDivScalarOp);
|
||||
#undef INSERT_BINARY_DIV_PATTERN
|
||||
|
||||
#define INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \
|
||||
target.addIllegalOp<AtenOp>(); \
|
||||
patterns.add<ConvertAtenMultipleDimsReductionOp<AtenOp, ConversionFunc>>( \
|
||||
|
@ -1348,10 +1554,9 @@ public:
|
|||
INSERT_ATENOP_PATTERN(AtenTanhOp);
|
||||
INSERT_ATENOP_PATTERN(AtenSigmoidOp);
|
||||
INSERT_ATENOP_PATTERN(AtenReluOp);
|
||||
INSERT_ATENOP_PATTERN(AtenMulTensorOp);
|
||||
INSERT_ATENOP_PATTERN(AtenDivTensorOp);
|
||||
INSERT_ATENOP_PATTERN(AtenArgmaxOp);
|
||||
INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
|
||||
INSERT_ATENOP_PATTERN(AtenRsubScalarOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
|
|
|
@ -141,5 +141,9 @@ template llvm::Optional<Value> getConstTensor<int32_t>(PatternRewriter &,
|
|||
ArrayRef<int32_t> vec,
|
||||
ArrayRef<int64_t> shape);
|
||||
|
||||
template llvm::Optional<Value> getConstTensor<int64_t>(PatternRewriter &,
|
||||
Operation *,
|
||||
ArrayRef<int64_t> vec,
|
||||
ArrayRef<int64_t> shape);
|
||||
} // namespace tosa
|
||||
} // namespace mlir
|
||||
|
|
|
@ -105,15 +105,46 @@ func @torch.aten.bitwise_not$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.ceil$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = "tosa.ceil"(%[[VAL_1]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK: }
|
||||
func @torch.aten.ceil$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%0 = torch.aten.ceil %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.reciprocal$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = "tosa.reciprocal"(%[[VAL_1]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK: }
|
||||
func @torch.aten.reciprocal$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%0 = torch.aten.reciprocal %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.add$basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[ARG1_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[ARG2:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.add"(%[[ARG0_BUILTIN]], %[[ARG1_BUILTIN]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[VAL_5:.*]] = "tosa.const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
// CHECK: %[[VAL_6:.*]] = "tosa.mul"(%[[VAL_3]], %[[VAL_5]]) {shift = 0 : i32} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_7:.*]] = "tosa.add"(%[[VAL_2]], %[[VAL_6]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK: }
|
||||
func @torch.aten.add$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
|
||||
%int1 = torch.constant.int 1
|
||||
%0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32>, !torch.int -> !torch.vtensor<[?, ?],f32>
|
||||
|
@ -123,14 +154,17 @@ func @torch.aten.add$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vten
|
|||
// -----
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.sub$basic(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[ARG1_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[ARG2:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.sub"(%[[ARG0_BUILTIN]], %[[ARG1_BUILTIN]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[VAL_5:.*]] = "tosa.const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
// CHECK: %[[VAL_6:.*]] = "tosa.mul"(%[[VAL_3]], %[[VAL_5]]) {shift = 0 : i32} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_7:.*]] = "tosa.sub"(%[[VAL_2]], %[[VAL_6]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK: }
|
||||
func @torch.aten.sub$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> {
|
||||
%int1 = torch.constant.int 1
|
||||
%0 = torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32>, !torch.int -> !torch.vtensor<[?, ?],f32>
|
||||
|
@ -347,3 +381,93 @@ func @torch.aten.pow.Tensor_Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !t
|
|||
%0 = torch.aten.pow.Tensor_Scalar %arg0, %fp0 : !torch.vtensor<[?,?],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.rsub.Scalar$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00
|
||||
// CHECK: %[[VAL_3:.*]] = torch.constant.float 6.432100e+00
|
||||
// CHECK: %[[VAL_4:.*]] = "tosa.const"() {value = dense<3.123400e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
// CHECK: %[[VAL_5:.*]] = "tosa.const"() {value = dense<6.432100e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
// CHECK: %[[VAL_6:.*]] = "tosa.mul"(%[[VAL_1]], %[[VAL_5]]) {shift = 0 : i32} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_7:.*]] = "tosa.sub"(%[[VAL_4]], %[[VAL_6]]) : (tensor<f32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK: }
|
||||
func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%other = torch.constant.float 3.123400e+00
|
||||
%alpha = torch.constant.float 6.432100e+00
|
||||
%0 = torch.aten.rsub.Scalar %arg0, %other, %alpha : !torch.vtensor<[?,?],f32>, !torch.float, !torch.float -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.rsub.Scalar$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00
|
||||
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[VAL_4:.*]] = "tosa.const"() {value = dense<3.123400e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
// CHECK: %[[VAL_5:.*]] = "tosa.const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
// CHECK: %[[VAL_6:.*]] = "tosa.mul"(%[[VAL_1]], %[[VAL_5]]) {shift = 0 : i32} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_7:.*]] = "tosa.sub"(%[[VAL_4]], %[[VAL_6]]) : (tensor<f32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK: }
|
||||
func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%other = torch.constant.float 3.123400e+00
|
||||
%alpha = torch.constant.int 1
|
||||
%0 = torch.aten.rsub.Scalar %arg0, %other, %alpha : !torch.vtensor<[?,?],f32>, !torch.float, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.gt.Tensor$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = "tosa.greater"(%[[VAL_2]], %[[VAL_3]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
|
||||
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
|
||||
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
|
||||
// CHECK: }
|
||||
func @torch.aten.gt.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
|
||||
%0 = torch.aten.gt.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1>
|
||||
return %0 : !torch.vtensor<[?,?],i1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.lt.Tensor$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = "tosa.greater"(%[[VAL_3]], %[[VAL_2]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
|
||||
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
|
||||
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
|
||||
// CHECK: }
|
||||
func @torch.aten.lt.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
|
||||
%0 = torch.aten.lt.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1>
|
||||
return %0 : !torch.vtensor<[?,?],i1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.eq.Tensor$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = "tosa.equal"(%[[VAL_2]], %[[VAL_3]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
|
||||
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
|
||||
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
|
||||
// CHECK: }
|
||||
func @torch.aten.eq.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
|
||||
%0 = torch.aten.eq.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1>
|
||||
return %0 : !torch.vtensor<[?,?],i1>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue