mirror of https://github.com/llvm/torch-mlir
[tosa] Support for Aten[Unsqueeze|Contiguous|Dropout|Reshape|View] ops (#700)
parent
6b637a9fd9
commit
5d7a6c2976
|
@ -144,4 +144,12 @@ TOSA_PASS_SET = {
|
|||
"SiluModule_basic",
|
||||
"DropoutEvalIntModule_basic",
|
||||
"DropoutEvalFloatModule_basic",
|
||||
"ContiguousModule_basic",
|
||||
"DropoutModule_basic",
|
||||
"ViewExpandModule_basic",
|
||||
"ViewCollapseInferredDimModule_basic",
|
||||
"ViewExpandInferredDimModule_basic",
|
||||
"ViewNoChangeStaticModule_basic",
|
||||
"UnsafeViewExpandModule_basic",
|
||||
"ReshapeCollapseModule_basic",
|
||||
}
|
||||
|
|
|
@ -2329,6 +2329,116 @@ LogicalResult ConvertAtenOp<AtenThresholdOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
|
||||
AtenUnsqueezeOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// Not a tensor type.
|
||||
auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
|
||||
if (!selfType) {
|
||||
return op.emitError("Only tensor types are currently supported");
|
||||
}
|
||||
|
||||
auto selfRank = selfType.getRank();
|
||||
auto selfElemTy = selfType.getElementType();
|
||||
if (!selfElemTy.isIntOrFloat()) {
|
||||
return op.emitError(
|
||||
"Only floating-point or integer datatype legalization supported");
|
||||
}
|
||||
|
||||
int64_t dim;
|
||||
if (!matchPattern(op.dim(), m_TorchConstantInt(&dim)))
|
||||
return op->emitError("dim must be a Scalar constant");
|
||||
|
||||
dim = toPositiveDim(dim, selfRank);
|
||||
if (!isValidDim(dim, selfRank))
|
||||
return op.emitError("dim is statically invalid");
|
||||
|
||||
SmallVector<int64_t> outShape;
|
||||
for (auto en : llvm::enumerate(selfType.getShape())) {
|
||||
if (static_cast<int64_t>(en.index()) == dim)
|
||||
outShape.push_back(1);
|
||||
|
||||
outShape.push_back(en.value());
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), adaptor.self(),
|
||||
rewriter.getI64ArrayAttr(outShape));
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenContiguousOp>::matchAndRewrite(
|
||||
AtenContiguousOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// Not a tensor type.
|
||||
auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
|
||||
if (!selfType)
|
||||
return op.emitError("Only tensor types are currently supported");
|
||||
|
||||
// FIXME: memory_format is not handled.
|
||||
|
||||
rewriter.replaceOp(op, adaptor.self());
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenDropoutOp>::matchAndRewrite(
|
||||
AtenDropoutOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// Not a tensor type.
|
||||
auto selfType = adaptor.input().getType().dyn_cast<TensorType>();
|
||||
if (!selfType)
|
||||
return op.emitError("Only tensor types are currently supported");
|
||||
|
||||
// FIXME: train and p are not handled.
|
||||
|
||||
bool train;
|
||||
if (!matchPattern(op.train(), m_TorchConstantBool(&train)))
|
||||
op.emitError("train must be a Scalar constant");
|
||||
|
||||
if (train)
|
||||
op.emitError("train must be false");
|
||||
|
||||
rewriter.replaceOpWithNewOp<tosa::CastOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), adaptor.input());
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenViewOp>::matchAndRewrite(
|
||||
AtenViewOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// Not a tensor type.
|
||||
auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
|
||||
if (!selfType)
|
||||
return op.emitError("Only tensor types are currently supported");
|
||||
|
||||
auto selfElemTy = selfType.getElementType();
|
||||
if (!selfElemTy.isIntOrFloat()) {
|
||||
return op.emitError(
|
||||
"Only floating-point or integer datatype legalization supported");
|
||||
}
|
||||
|
||||
SmallVector<int64_t> outShape;
|
||||
if (!matchPattern(op.size(), m_TorchConstantIntList(outShape)))
|
||||
return op.emitError("size must consist of Scalar constants");
|
||||
|
||||
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), adaptor.self(),
|
||||
rewriter.getI64ArrayAttr(outShape));
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
template <typename AtenOpT, typename TosaOpT>
|
||||
class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
|
@ -2879,6 +2989,10 @@ public:
|
|||
INSERT_ATENOP_PATTERN(AtenPermuteOp);
|
||||
INSERT_ATENOP_PATTERN(AtenLog2Op);
|
||||
INSERT_ATENOP_PATTERN(AtenThresholdOp);
|
||||
INSERT_ATENOP_PATTERN(AtenUnsqueezeOp);
|
||||
INSERT_ATENOP_PATTERN(AtenContiguousOp);
|
||||
INSERT_ATENOP_PATTERN(AtenDropoutOp);
|
||||
INSERT_ATENOP_PATTERN(AtenViewOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
|
|
|
@ -492,8 +492,8 @@ func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.v
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @forward(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,4,3],f32>) -> !torch.vtensor<[10,4,3],f32> {
|
||||
// CHECK-LABEL: func @torch.aten.native_batch_norm$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,4,3],f32>) -> !torch.vtensor<[10,4,3],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[10,4,3],f32> -> tensor<10x4x3xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = "tosa.const"() {value = dense<[5.000000e-01, 4.000000e-01, 3.000000e-01, 6.000000e-01]> : tensor<4xf32>} : () -> tensor<4xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = "tosa.const"() {value = dense<[3.000000e+00, 2.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<4xf32>} : () -> tensor<4xf32>
|
||||
|
@ -515,7 +515,7 @@ func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.v
|
|||
// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<10x4x3xf32> -> !torch.vtensor<[10,4,3],f32>
|
||||
// CHECK: return %[[VAL_19]] : !torch.vtensor<[10,4,3],f32>
|
||||
// CHECK: }
|
||||
func @forward(%arg0: !torch.vtensor<[10,4,3],f32> ) -> !torch.vtensor<[10,4,3],f32> {
|
||||
func @torch.aten.native_batch_norm$basic(%arg0: !torch.vtensor<[10,4,3],f32> ) -> !torch.vtensor<[10,4,3],f32> {
|
||||
%0 = torch.vtensor.literal(dense<[5.000000e-01, 4.000000e-01, 3.000000e-01, 6.000000e-01]> : tensor<4xf32>) : !torch.vtensor<[4],f32>
|
||||
%1 = torch.vtensor.literal(dense<[3.000000e+00, 2.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<4xf32>) : !torch.vtensor<[4],f32>
|
||||
%float1.000000e-01 = torch.constant.float 1.000000e-01
|
||||
|
@ -689,6 +689,38 @@ func @torch.aten.zeros$basic() -> !torch.vtensor<[3,4],f32> {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.unsqueeze$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,3],si32>) -> !torch.vtensor<[4,1,3],si32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],si32> -> tensor<4x3xi32>
|
||||
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = [4, 1, 3]} : (tensor<4x3xi32>) -> tensor<4x1x3xi32>
|
||||
// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<4x1x3xi32> -> !torch.vtensor<[4,1,3],si32>
|
||||
// CHECK: return %[[VAL_4]] : !torch.vtensor<[4,1,3],si32>
|
||||
// CHECK: }
|
||||
|
||||
func @torch.aten.unsqueeze$basic(%arg0: !torch.vtensor<[4,3],si32> ) -> !torch.vtensor<[4,1,3],si32> {
|
||||
%int1 = torch.constant.int 1
|
||||
%0 = torch.aten.unsqueeze %arg0, %int1 : !torch.vtensor<[4,3],si32>, !torch.int -> !torch.vtensor<[4,1,3],si32>
|
||||
return %0 : !torch.vtensor<[4,1,3],si32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.contiguous$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_1]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK: }
|
||||
func @torch.aten.contiguous$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vtensor<[?,?],f32> {
|
||||
%int0 = torch.constant.int 0
|
||||
%0 = torch.aten.contiguous %arg0, %int0 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.ones$basic() -> !torch.vtensor<[3,4],f32> {
|
||||
// CHECK: %[[VAL_0:.*]] = torch.constant.int 4
|
||||
// CHECK: %[[VAL_1:.*]] = torch.constant.int 3
|
||||
|
@ -707,3 +739,21 @@ func @torch.aten.ones$basic() -> !torch.vtensor<[3,4],f32> {
|
|||
%1 = torch.aten.ones %0, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4],f32>
|
||||
return %1 : !torch.vtensor<[3,4],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.dropout$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = torch.constant.float 0.000000e+00
|
||||
// CHECK: %[[VAL_3:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[VAL_4:.*]] = "tosa.cast"(%[[VAL_1]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK: }
|
||||
func @torch.aten.dropout$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vtensor<[?,?],f32> {
|
||||
%float0.000000e00 = torch.constant.float 0.000000e+00
|
||||
%false = torch.constant.bool false
|
||||
%0 = torch.aten.dropout %arg0, %float0.000000e00, %false : !torch.vtensor<[?,?],f32>, !torch.float, !torch.bool -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue