mirror of https://github.com/llvm/torch-mlir
fix i64 torch.tensor dtype
parent
3c85903c2f
commit
329d02061b
|
@ -1288,6 +1288,55 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenSliceScatterOp
|
||||
: public OpConversionPattern<AtenSliceScatterOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenSliceScatterOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
|
||||
Location loc = op.getLoc();
|
||||
TypeConverter *typeConverter = getTypeConverter();
|
||||
|
||||
auto input = adaptor.self();
|
||||
|
||||
RankedTensorType resultType =
|
||||
typeConverter->convertType(op->getResult(0).getType())
|
||||
.cast<RankedTensorType>();
|
||||
|
||||
SmallVector<Value> resultShape;
|
||||
SmallVector<Value> offsets;
|
||||
SmallVector<Value> strides;
|
||||
if (failed(prepareArgumentsForSlicingOp<AtenSliceScatterOp,
|
||||
AtenSliceScatterOpAdaptor>(
|
||||
op, adaptor, rewriter, resultShape, offsets, strides))) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
Value src = adaptor.src();
|
||||
auto srcType = src.getType().cast<RankedTensorType>();
|
||||
int64_t srcRank = srcType.getRank();
|
||||
SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
|
||||
auto abstractSrcType =
|
||||
RankedTensorType::get(srcAbstractSizes, srcType.getElementType());
|
||||
Value abstractSrc =
|
||||
rewriter.create<tensor::CastOp>(loc, abstractSrcType, src);
|
||||
|
||||
Value result = rewriter.create<tensor::InsertSliceOp>(
|
||||
loc, abstractSrc, input, offsets, resultShape, strides);
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target) {
|
||||
|
@ -1316,4 +1365,6 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
|
|||
patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenCopyOp>();
|
||||
patterns.add<ConvertAtenCopyOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenSliceScatterOp>();
|
||||
patterns.add<ConvertAtenSliceScatterOp>(typeConverter, context);
|
||||
}
|
||||
|
|
|
@ -2986,6 +2986,16 @@ public:
|
|||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenSliceScatterOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
int64_t inputRank = getTensorRank(op.self());
|
||||
int64_t dimInt = 0;
|
||||
if (matchPattern(op.dim(), m_TorchConstantInt(&dimInt))) {
|
||||
dimInt = toPositiveDim(dimInt, inputRank);
|
||||
if (!isValidDim(dimInt, inputRank))
|
||||
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
|
||||
} else {
|
||||
return rewriter.notifyMatchFailure(op, "dim must be constant");
|
||||
}
|
||||
|
||||
auto getOptionalVal = [&](Value val, Value defVal) -> Value {
|
||||
if (val.getType().isa<Torch::NoneType>()) {
|
||||
return defVal;
|
||||
|
@ -3005,15 +3015,31 @@ public:
|
|||
Value start = getOptionalVal(op.start(), zero);
|
||||
Value end = getOptionalVal(op.end(), dimSize);
|
||||
Value step = getOptionalVal(op.step(), one);
|
||||
// Step 1. calculate scatter indices mask
|
||||
// Step 0. create indices
|
||||
Type indicesType = ValueTensorType::get(
|
||||
op.getContext(), ArrayRef<int64_t>{ShapedType::kDynamicSize},
|
||||
IntegerType::get(op.getContext(), 64));
|
||||
Type maskType = ValueTensorType::get(
|
||||
op.getContext(), ArrayRef<int64_t>{ShapedType::kDynamicSize},
|
||||
IntegerType::get(op.getContext(), 1));
|
||||
auto indices = rewriter.create<AtenArangeOp>(
|
||||
IntegerType::get(op.getContext(), 64, IntegerType::Signed));
|
||||
Value indices = rewriter.create<AtenArangeOp>(
|
||||
op.getLoc(), indicesType, dimSize, none, none, none, none);
|
||||
|
||||
// Step 1. make indices broadcastable to self's shape
|
||||
SmallVector<int64_t> newIndicesShapeInt(inputRank, 1);
|
||||
SmallVector<Value> newIndicesShape(inputRank, one);
|
||||
newIndicesShape[dimInt] = dimSize;
|
||||
newIndicesShapeInt[dimInt] = ShapedType::kDynamicSize;
|
||||
Value newIndicesSizeList = rewriter.create<PrimListConstructOp>(
|
||||
op.getLoc(), ListType::get(IntType::get(op.getContext())),
|
||||
newIndicesShape);
|
||||
Type indicesDtype = indices.getType().cast<ValueTensorType>().getDtype();
|
||||
Type newIndicesType = ValueTensorType::get(
|
||||
op.getContext(), llvm::makeArrayRef(newIndicesShapeInt), indicesDtype);
|
||||
indices = rewriter.create<AtenViewOp>(op.getLoc(), newIndicesType,
|
||||
indices, newIndicesSizeList);
|
||||
|
||||
// Step 2. calculate scatter indices mask
|
||||
Type maskType = ValueTensorType::get(
|
||||
op.getContext(), newIndicesType.cast<ValueTensorType>().getSizes(),
|
||||
IntegerType::get(op.getContext(), 1));
|
||||
auto shiftIndices = rewriter.create<AtenSubScalarOp>(
|
||||
op.getLoc(), indices.getType(), indices, start, one);
|
||||
auto stepRemainder = rewriter.create<AtenRemainderScalarOp>(
|
||||
|
@ -3023,36 +3049,12 @@ public:
|
|||
auto maskStart = rewriter.create<AtenGeScalarOp>(op.getLoc(), maskType,
|
||||
shiftIndices, zero);
|
||||
auto maskEnd =
|
||||
rewriter.create<AtenGeScalarOp>(op.getLoc(), maskType, indices, end);
|
||||
rewriter.create<AtenLtScalarOp>(op.getLoc(), maskType, indices, end);
|
||||
mask = rewriter.create<AtenBitwiseAndTensorOp>(op.getLoc(), maskType, mask,
|
||||
maskStart);
|
||||
mask = rewriter.create<AtenBitwiseAndTensorOp>(op.getLoc(), maskType, mask,
|
||||
maskEnd);
|
||||
|
||||
int64_t inputRank = getTensorRank(op.self());
|
||||
int64_t dimInt = 0;
|
||||
if (matchPattern(op.dim(), m_TorchConstantInt(&dimInt))) {
|
||||
dimInt = toPositiveDim(dimInt, inputRank);
|
||||
if (!isValidDim(dimInt, inputRank))
|
||||
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
|
||||
} else {
|
||||
return rewriter.notifyMatchFailure(op, "dim must be constant");
|
||||
}
|
||||
|
||||
// Step 2. make mask broadcastable to self's shape
|
||||
SmallVector<int64_t> maskViewShapeInt(inputRank, 1);
|
||||
SmallVector<Value> maskViewShape(inputRank, one);
|
||||
maskViewShape[dimInt] = dimSize;
|
||||
maskViewShapeInt[dimInt] = ShapedType::kDynamicSize;
|
||||
Value maskViewSizeList = rewriter.create<PrimListConstructOp>(
|
||||
op.getLoc(), ListType::get(IntType::get(op.getContext())),
|
||||
maskViewShape);
|
||||
Type maskDtype = mask.getType().cast<ValueTensorType>().getDtype();
|
||||
Type maskViewType = ValueTensorType::get(
|
||||
op.getContext(), llvm::makeArrayRef(maskViewShapeInt), maskDtype);
|
||||
Value maskView = rewriter.create<AtenViewOp>(op.getLoc(), maskViewType,
|
||||
mask, maskViewSizeList);
|
||||
|
||||
// Step 3. make src broadcastable to self's shape
|
||||
Value src = op.src();
|
||||
BaseTensorType srcTensorType = src.getType().cast<BaseTensorType>();
|
||||
|
@ -3076,7 +3078,7 @@ public:
|
|||
}
|
||||
|
||||
// Step 4. replace output = mask? src: self
|
||||
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, op.getType(), maskView,
|
||||
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, op.getType(), mask,
|
||||
src, op.self());
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -240,7 +240,7 @@ class ExampleArgs:
|
|||
# compiler where each backend can "own" its set of legal ops.
|
||||
BACKEND_LEGAL_OPS = {
|
||||
OutputType.TOSA: ['torch.aten.flatten.using_ints', 'torch.aten.native_layer_norm', 'torch.aten.linear'],
|
||||
OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints', ],
|
||||
OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints', 'torch.aten.slice_scatter'],
|
||||
OutputType.MHLO: [],
|
||||
}
|
||||
|
||||
|
|
|
@ -813,26 +813,26 @@ func.func @torch.aten.repeat(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int
|
|||
// CHECK-NEXT: %[[INT0:.*]] = torch.constant.int 0
|
||||
// CHECK-NEXT: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK-NEXT: %[[INT1_0:.*]] = torch.constant.int 1
|
||||
// CHECK-NEXT: %[[T0:.*]] = torch.aten.add.int %[[INT0]], %[[INT1_0]]
|
||||
// CHECK-NEXT: %[[T1:.*]] = torch.aten.unsqueeze %[[SRC]], %[[INT1]]
|
||||
// CHECK-NEXT: %[[T0:.*]] = torch.aten.add.int %[[INT0]], %[[INT1_0]] : !torch.int, !torch.int -> !torch.int
|
||||
// CHECK-NEXT: %[[T1:.*]] = torch.aten.unsqueeze %[[SRC]], %[[INT1]] : !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?,1],f32>
|
||||
// CHECK-NEXT: %[[INT1_1:.*]] = torch.constant.int 1
|
||||
// CHECK-NEXT: %[[INT0_2:.*]] = torch.constant.int 0
|
||||
// CHECK-NEXT: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK-NEXT: %[[T2:.*]] = torch.aten.size.int %[[SELF]], %[[INT1]]
|
||||
// CHECK-NEXT: %[[T2:.*]] = torch.aten.size.int %[[SELF]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
|
||||
// CHECK-NEXT: %[[INT0_3:.*]] = torch.constant.int 0
|
||||
// CHECK-NEXT: %[[INT1_4:.*]] = torch.constant.int 1
|
||||
// CHECK-NEXT: %[[T3:.*]] = torch.aten.arange.start_step %[[INT0_3]], %[[T2]], %[[INT1_4]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]]
|
||||
// CHECK-NEXT: %[[T4:.*]] = torch.aten.sub.Scalar %[[T3]], %[[INT0]], %[[INT1_1]]
|
||||
// CHECK-NEXT: %[[T5:.*]] = torch.aten.remainder.Scalar %[[T4]], %[[INT1_0]]
|
||||
// CHECK-NEXT: %[[T6:.*]] = torch.aten.eq.Scalar %[[T5]], %[[INT0_2]]
|
||||
// CHECK-NEXT: %[[T7:.*]] = torch.aten.ge.Scalar %[[T4]], %[[INT0_2]]
|
||||
// CHECK-NEXT: %[[T8:.*]] = torch.aten.ge.Scalar %[[T3]], %[[T0]]
|
||||
// CHECK-NEXT: %[[T9:.*]] = torch.aten.bitwise_and.Tensor %[[T6]], %[[T7]]
|
||||
// CHECK-NEXT: %[[T10:.*]] = torch.aten.bitwise_and.Tensor %[[T9]], %[[T8]]
|
||||
// CHECK-NEXT: %[[T11:.*]] = torch.prim.ListConstruct %[[INT1_1]], %[[T2]]
|
||||
// CHECK-NEXT: %[[T12:.*]] = torch.aten.view %[[T10]], %[[T11]]
|
||||
// CHECK-NEXT: %[[T13:.*]] = torch.aten.where.self %[[T12]], %[[T1]], %[[SELF]]
|
||||
// CHECK-NEXT: return %[[T13]]
|
||||
// CHECK-NEXT: %[[T3:.*]] = torch.aten.arange.start_step %[[INT0_3]], %[[T2]], %[[INT1_4]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],si64>
|
||||
// CHECK-NEXT: %[[T4:.*]] = torch.prim.ListConstruct %[[INT1_1]], %[[T2]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
// CHECK-NEXT: %[[T5:.*]] = torch.aten.view %[[T3]], %[[T4]] : !torch.vtensor<[?],si64>, !torch.list<int> -> !torch.vtensor<[1,?],si64>
|
||||
// CHECK-NEXT: %[[T6:.*]] = torch.aten.sub.Scalar %[[T5]], %[[INT0]], %[[INT1_1]] : !torch.vtensor<[1,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,?],si64>
|
||||
// CHECK-NEXT: %[[T7:.*]] = torch.aten.remainder.Scalar %[[T6]], %[[INT1_0]] : !torch.vtensor<[1,?],si64>, !torch.int -> !torch.vtensor<[1,?],si64>
|
||||
// CHECK-NEXT: %[[T8:.*]] = torch.aten.eq.Scalar %[[T7]], %[[INT0_2]] : !torch.vtensor<[1,?],si64>, !torch.int -> !torch.vtensor<[1,?],i1>
|
||||
// CHECK-NEXT: %[[T9:.*]] = torch.aten.ge.Scalar %[[T6]], %[[INT0_2]] : !torch.vtensor<[1,?],si64>, !torch.int -> !torch.vtensor<[1,?],i1>
|
||||
// CHECK-NEXT: %[[T10:.*]] = torch.aten.lt.Scalar %[[T5]], %[[T0]] : !torch.vtensor<[1,?],si64>, !torch.int -> !torch.vtensor<[1,?],i1>
|
||||
// CHECK-NEXT: %[[T11:.*]] = torch.aten.bitwise_and.Tensor %[[T8]], %[[T9]] : !torch.vtensor<[1,?],i1>, !torch.vtensor<[1,?],i1> -> !torch.vtensor<[1,?],i1>
|
||||
// CHECK-NEXT: %[[T12:.*]] = torch.aten.bitwise_and.Tensor %[[T11]], %[[T10]] : !torch.vtensor<[1,?],i1>, !torch.vtensor<[1,?],i1> -> !torch.vtensor<[1,?],i1>
|
||||
// CHECK-NEXT: %[[T13:.*]] = torch.aten.where.self %[[T12]], %[[T1]], %[[SELF]] : !torch.vtensor<[1,?],i1>, !torch.vtensor<[?,1],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK-NEXT: return %[[T13]] : !torch.vtensor<[?,?],f32>
|
||||
func.func @torch.aten.select_scatter(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%int0 = torch.constant.int 0
|
||||
%int1 = torch.constant.int 1
|
||||
|
|
Loading…
Reference in New Issue