fix i64 torch.tensor dtype

tanyo/slice_scatter_stage
TanyoKwok 2022-11-22 11:44:18 +08:00
parent 3c85903c2f
commit 329d02061b
4 changed files with 101 additions and 48 deletions
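
For context on the op being wired up here: aten.slice_scatter writes `src` over a strided slice of `self` and returns a new tensor (the functional form of slice assignment). A quick eager-mode reference, mirroring the example in the PyTorch docs (not part of the diff):

    import torch

    self_t = torch.zeros(8, 8)
    src = torch.ones(2, 8)
    # Functional form of `self_t[6:] = src`: returns a new tensor,
    # leaving `self_t` untouched.
    out = torch.slice_scatter(self_t, src, dim=0, start=6)

    ref = self_t.clone()
    ref[6:] = src
    assert torch.equal(out, ref)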

@@ -1288,6 +1288,55 @@ public:
 };
 } // namespace
+
+namespace {
+class ConvertAtenSliceScatterOp
+    : public OpConversionPattern<AtenSliceScatterOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenSliceScatterOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    Location loc = op.getLoc();
+    TypeConverter *typeConverter = getTypeConverter();
+    auto input = adaptor.self();
+
+    RankedTensorType resultType =
+        typeConverter->convertType(op->getResult(0).getType())
+            .cast<RankedTensorType>();
+
+    SmallVector<Value> resultShape;
+    SmallVector<Value> offsets;
+    SmallVector<Value> strides;
+    if (failed(prepareArgumentsForSlicingOp<AtenSliceScatterOp,
+                                            AtenSliceScatterOpAdaptor>(
+            op, adaptor, rewriter, resultShape, offsets, strides))) {
+      return failure();
+    }
+
+    Value src = adaptor.src();
+    auto srcType = src.getType().cast<RankedTensorType>();
+    int64_t srcRank = srcType.getRank();
+    SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
+    auto abstractSrcType =
+        RankedTensorType::get(srcAbstractSizes, srcType.getElementType());
+    Value abstractSrc =
+        rewriter.create<tensor::CastOp>(loc, abstractSrcType, src);
+
+    Value result = rewriter.create<tensor::InsertSliceOp>(
+        loc, abstractSrc, input, offsets, resultShape, strides);
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
+    return success();
+  }
+};
+} // namespace
+
 void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target) {
@@ -1316,4 +1365,6 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
   patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
   target.addIllegalOp<AtenCopyOp>();
   patterns.add<ConvertAtenCopyOp>(typeConverter, context);
+  target.addIllegalOp<AtenSliceScatterOp>();
+  patterns.add<ConvertAtenSliceScatterOp>(typeConverter, context);
 }
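
The pattern above turns aten.slice_scatter into a tensor.insert_slice: prepareArgumentsForSlicingOp derives the offsets, sizes, and strides from start/end/step, and `src` is first cast to a fully dynamic tensor type so the insertion verifies regardless of static shapes. In eager terms it is a strided slice assignment (an illustrative sketch, not part of the commit; slice_scatter_ref is a hypothetical helper):

    import torch

    def slice_scatter_ref(self_t, src, dim, start, end, step):
        out = self_t.clone()                  # slice_scatter is out-of-place
        idx = [slice(None)] * self_t.dim()
        idx[dim] = slice(start, end, step)    # the offsets/sizes/strides
        out[tuple(idx)] = src
        return out

    x, src = torch.zeros(4, 6), torch.ones(4, 3)
    assert torch.equal(slice_scatter_ref(x, src, 1, 0, 6, 2),
                       torch.slice_scatter(x, src, 1, 0, 6, 2))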

@@ -2986,6 +2986,16 @@ public:
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenSliceScatterOp op,
                                 PatternRewriter &rewriter) const override {
+    int64_t inputRank = getTensorRank(op.self());
+    int64_t dimInt = 0;
+    if (matchPattern(op.dim(), m_TorchConstantInt(&dimInt))) {
+      dimInt = toPositiveDim(dimInt, inputRank);
+      if (!isValidDim(dimInt, inputRank))
+        return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
+    } else {
+      return rewriter.notifyMatchFailure(op, "dim must be constant");
+    }
+
     auto getOptionalVal = [&](Value val, Value defVal) -> Value {
       if (val.getType().isa<Torch::NoneType>()) {
         return defVal;
@@ -3005,15 +3015,31 @@ public:
     Value start = getOptionalVal(op.start(), zero);
     Value end = getOptionalVal(op.end(), dimSize);
     Value step = getOptionalVal(op.step(), one);
-    // Step 1. calculate scatter indices mask
+    // Step 0. create indices
     Type indicesType = ValueTensorType::get(
         op.getContext(), ArrayRef<int64_t>{ShapedType::kDynamicSize},
-        IntegerType::get(op.getContext(), 64));
-    Type maskType = ValueTensorType::get(
-        op.getContext(), ArrayRef<int64_t>{ShapedType::kDynamicSize},
-        IntegerType::get(op.getContext(), 1));
-    auto indices = rewriter.create<AtenArangeOp>(
+        IntegerType::get(op.getContext(), 64, IntegerType::Signed));
+    Value indices = rewriter.create<AtenArangeOp>(
         op.getLoc(), indicesType, dimSize, none, none, none, none);
+
+    // Step 1. make indices broadcastable to self's shape
+    SmallVector<int64_t> newIndicesShapeInt(inputRank, 1);
+    SmallVector<Value> newIndicesShape(inputRank, one);
+    newIndicesShape[dimInt] = dimSize;
+    newIndicesShapeInt[dimInt] = ShapedType::kDynamicSize;
+    Value newIndicesSizeList = rewriter.create<PrimListConstructOp>(
+        op.getLoc(), ListType::get(IntType::get(op.getContext())),
+        newIndicesShape);
+    Type indicesDtype = indices.getType().cast<ValueTensorType>().getDtype();
+    Type newIndicesType = ValueTensorType::get(
+        op.getContext(), llvm::makeArrayRef(newIndicesShapeInt), indicesDtype);
+    indices = rewriter.create<AtenViewOp>(op.getLoc(), newIndicesType,
+                                          indices, newIndicesSizeList);
+
+    // Step 2. calculate scatter indices mask
+    Type maskType = ValueTensorType::get(
+        op.getContext(), newIndicesType.cast<ValueTensorType>().getSizes(),
+        IntegerType::get(op.getContext(), 1));
     auto shiftIndices = rewriter.create<AtenSubScalarOp>(
         op.getLoc(), indices.getType(), indices, start, one);
     auto stepRemainder = rewriter.create<AtenRemainderScalarOp>(
@@ -3023,36 +3049,12 @@ public:
     auto maskStart = rewriter.create<AtenGeScalarOp>(op.getLoc(), maskType,
                                                      shiftIndices, zero);
     auto maskEnd =
-        rewriter.create<AtenGeScalarOp>(op.getLoc(), maskType, indices, end);
+        rewriter.create<AtenLtScalarOp>(op.getLoc(), maskType, indices, end);
     mask = rewriter.create<AtenBitwiseAndTensorOp>(op.getLoc(), maskType, mask,
                                                    maskStart);
     mask = rewriter.create<AtenBitwiseAndTensorOp>(op.getLoc(), maskType, mask,
                                                    maskEnd);
-
-    int64_t inputRank = getTensorRank(op.self());
-    int64_t dimInt = 0;
-    if (matchPattern(op.dim(), m_TorchConstantInt(&dimInt))) {
-      dimInt = toPositiveDim(dimInt, inputRank);
-      if (!isValidDim(dimInt, inputRank))
-        return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
-    } else {
-      return rewriter.notifyMatchFailure(op, "dim must be constant");
-    }
-
-    // Step 2. make mask broadcastable to self's shape
-    SmallVector<int64_t> maskViewShapeInt(inputRank, 1);
-    SmallVector<Value> maskViewShape(inputRank, one);
-    maskViewShape[dimInt] = dimSize;
-    maskViewShapeInt[dimInt] = ShapedType::kDynamicSize;
-    Value maskViewSizeList = rewriter.create<PrimListConstructOp>(
-        op.getLoc(), ListType::get(IntType::get(op.getContext())),
-        maskViewShape);
-    Type maskDtype = mask.getType().cast<ValueTensorType>().getDtype();
-    Type maskViewType = ValueTensorType::get(
-        op.getContext(), llvm::makeArrayRef(maskViewShapeInt), maskDtype);
-    Value maskView = rewriter.create<AtenViewOp>(op.getLoc(), maskViewType,
-                                                 mask, maskViewSizeList);
+
     // Step 3. make src broadcastable to self's shape
     Value src = op.src();
     BaseTensorType srcTensorType = src.getType().cast<BaseTensorType>();
@@ -3076,7 +3078,7 @@ public:
     }
     // Step 4. replace output = mask? src: self
-    rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, op.getType(), maskView,
+    rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, op.getType(), mask,
                                                  src, op.self());
     return success();
   }
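
To see what this decomposition computes, here is an eager-PyTorch mirror of Steps 0-4 (an illustrative sketch, not part of the commit; decomposed_select_scatter is a hypothetical helper). It is specialized to a width-1 slice, i.e. the select_scatter case exercised by the test below, since the final where relies on `src` broadcasting along the scatter dimension:

    import torch

    def decomposed_select_scatter(self_t, src, dim, index):
        start, end, step = index, index + 1, 1   # a width-1 slice
        dim_size = self_t.size(dim)
        view_shape = [1] * self_t.dim()
        view_shape[dim] = dim_size
        # Step 0/1: arange yields signed i64 (si64) indices, matching the
        # dtype fix above, viewed so they broadcast against `self`.
        indices = torch.arange(dim_size).view(view_shape)
        shifted = indices - start
        # Step 2: note `indices < end`, the AtenGeScalarOp to AtenLtScalarOp
        # fix above; `indices >= end` would mask the complement of the slice.
        mask = (shifted % step == 0) & (shifted >= 0) & (indices < end)
        # Step 3: unsqueeze `src` so it broadcasts along `dim`.
        # Step 4: output = mask ? src : self
        return torch.where(mask, src.unsqueeze(dim), self_t)

    x, row = torch.zeros(3, 4), torch.arange(4.0)
    assert torch.equal(decomposed_select_scatter(x, row, 0, 1),
                       torch.select_scatter(x, row, 0, 1))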

@@ -240,7 +240,7 @@ class ExampleArgs:
 # compiler where each backend can "own" its set of legal ops.
 BACKEND_LEGAL_OPS = {
     OutputType.TOSA: ['torch.aten.flatten.using_ints', 'torch.aten.native_layer_norm', 'torch.aten.linear'],
-    OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints', ],
+    OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints', 'torch.aten.slice_scatter'],
     OutputType.MHLO: [],
 }
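
Making 'torch.aten.slice_scatter' backend-legal for LINALG_ON_TENSORS keeps the decomposition pass from rewriting the op, so the new ConvertAtenSliceScatterOp pattern handles it instead. A minimal sketch of exercising this path through the Python frontend, assuming the torch_mlir.compile entry point of this era (not part of the diff):

    import torch
    import torch_mlir

    class SliceScatter(torch.nn.Module):
        def forward(self, x, src):
            return torch.slice_scatter(x, src, dim=1, start=0, end=6, step=2)

    # aten.slice_scatter survives decomposition (it is backend-legal above)
    # and is lowered by the new TorchToLinalg pattern instead.
    module = torch_mlir.compile(
        SliceScatter(), [torch.zeros(4, 6), torch.ones(4, 3)],
        output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)
    print(module)  # expect tensor.insert_slice in the linalg-on-tensors IR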

@@ -813,26 +813,26 @@ func.func @torch.aten.repeat(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int
 // CHECK-NEXT: %[[INT0:.*]] = torch.constant.int 0
 // CHECK-NEXT: %[[INT1:.*]] = torch.constant.int 1
 // CHECK-NEXT: %[[INT1_0:.*]] = torch.constant.int 1
-// CHECK-NEXT: %[[T0:.*]] = torch.aten.add.int %[[INT0]], %[[INT1_0]]
-// CHECK-NEXT: %[[T1:.*]] = torch.aten.unsqueeze %[[SRC]], %[[INT1]]
+// CHECK-NEXT: %[[T0:.*]] = torch.aten.add.int %[[INT0]], %[[INT1_0]] : !torch.int, !torch.int -> !torch.int
+// CHECK-NEXT: %[[T1:.*]] = torch.aten.unsqueeze %[[SRC]], %[[INT1]] : !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?,1],f32>
 // CHECK-NEXT: %[[INT1_1:.*]] = torch.constant.int 1
 // CHECK-NEXT: %[[INT0_2:.*]] = torch.constant.int 0
 // CHECK-NEXT: %[[NONE:.*]] = torch.constant.none
-// CHECK-NEXT: %[[T2:.*]] = torch.aten.size.int %[[SELF]], %[[INT1]]
+// CHECK-NEXT: %[[T2:.*]] = torch.aten.size.int %[[SELF]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
 // CHECK-NEXT: %[[INT0_3:.*]] = torch.constant.int 0
 // CHECK-NEXT: %[[INT1_4:.*]] = torch.constant.int 1
-// CHECK-NEXT: %[[T3:.*]] = torch.aten.arange.start_step %[[INT0_3]], %[[T2]], %[[INT1_4]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]]
-// CHECK-NEXT: %[[T4:.*]] = torch.aten.sub.Scalar %[[T3]], %[[INT0]], %[[INT1_1]]
-// CHECK-NEXT: %[[T5:.*]] = torch.aten.remainder.Scalar %[[T4]], %[[INT1_0]]
-// CHECK-NEXT: %[[T6:.*]] = torch.aten.eq.Scalar %[[T5]], %[[INT0_2]]
-// CHECK-NEXT: %[[T7:.*]] = torch.aten.ge.Scalar %[[T4]], %[[INT0_2]]
-// CHECK-NEXT: %[[T8:.*]] = torch.aten.ge.Scalar %[[T3]], %[[T0]]
-// CHECK-NEXT: %[[T9:.*]] = torch.aten.bitwise_and.Tensor %[[T6]], %[[T7]]
-// CHECK-NEXT: %[[T10:.*]] = torch.aten.bitwise_and.Tensor %[[T9]], %[[T8]]
-// CHECK-NEXT: %[[T11:.*]] = torch.prim.ListConstruct %[[INT1_1]], %[[T2]]
-// CHECK-NEXT: %[[T12:.*]] = torch.aten.view %[[T10]], %[[T11]]
-// CHECK-NEXT: %[[T13:.*]] = torch.aten.where.self %[[T12]], %[[T1]], %[[SELF]]
-// CHECK-NEXT: return %[[T13]]
+// CHECK-NEXT: %[[T3:.*]] = torch.aten.arange.start_step %[[INT0_3]], %[[T2]], %[[INT1_4]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],si64>
+// CHECK-NEXT: %[[T4:.*]] = torch.prim.ListConstruct %[[INT1_1]], %[[T2]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK-NEXT: %[[T5:.*]] = torch.aten.view %[[T3]], %[[T4]] : !torch.vtensor<[?],si64>, !torch.list<int> -> !torch.vtensor<[1,?],si64>
+// CHECK-NEXT: %[[T6:.*]] = torch.aten.sub.Scalar %[[T5]], %[[INT0]], %[[INT1_1]] : !torch.vtensor<[1,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,?],si64>
+// CHECK-NEXT: %[[T7:.*]] = torch.aten.remainder.Scalar %[[T6]], %[[INT1_0]] : !torch.vtensor<[1,?],si64>, !torch.int -> !torch.vtensor<[1,?],si64>
+// CHECK-NEXT: %[[T8:.*]] = torch.aten.eq.Scalar %[[T7]], %[[INT0_2]] : !torch.vtensor<[1,?],si64>, !torch.int -> !torch.vtensor<[1,?],i1>
+// CHECK-NEXT: %[[T9:.*]] = torch.aten.ge.Scalar %[[T6]], %[[INT0_2]] : !torch.vtensor<[1,?],si64>, !torch.int -> !torch.vtensor<[1,?],i1>
+// CHECK-NEXT: %[[T10:.*]] = torch.aten.lt.Scalar %[[T5]], %[[T0]] : !torch.vtensor<[1,?],si64>, !torch.int -> !torch.vtensor<[1,?],i1>
+// CHECK-NEXT: %[[T11:.*]] = torch.aten.bitwise_and.Tensor %[[T8]], %[[T9]] : !torch.vtensor<[1,?],i1>, !torch.vtensor<[1,?],i1> -> !torch.vtensor<[1,?],i1>
+// CHECK-NEXT: %[[T12:.*]] = torch.aten.bitwise_and.Tensor %[[T11]], %[[T10]] : !torch.vtensor<[1,?],i1>, !torch.vtensor<[1,?],i1> -> !torch.vtensor<[1,?],i1>
+// CHECK-NEXT: %[[T13:.*]] = torch.aten.where.self %[[T12]], %[[T1]], %[[SELF]] : !torch.vtensor<[1,?],i1>, !torch.vtensor<[?,1],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+// CHECK-NEXT: return %[[T13]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.select_scatter(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
   %int0 = torch.constant.int 0
   %int1 = torch.constant.int 1