mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add decomposition for prims.var and prims.sqrt op
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/1794/head
parent
b966733e04
commit
fd236b2c89
|
@ -83,6 +83,8 @@ TORCHDYNAMO_XFAIL_SET = {
|
||||||
# error: unsupported by backend contract: tensor with unknown rank
|
# error: unsupported by backend contract: tensor with unknown rank
|
||||||
# note: see current operation: %1 = "torch.tensor_static_info_cast"(%arg0) : (!torch.vtensor<[5,4,3,2,1],f32>) -> !torch.vtensor<*,f32>
|
# note: see current operation: %1 = "torch.tensor_static_info_cast"(%arg0) : (!torch.vtensor<[5,4,3,2,1],f32>) -> !torch.vtensor<*,f32>
|
||||||
"ElementwisePreluModule_basic",
|
"ElementwisePreluModule_basic",
|
||||||
|
# error: op lowering missing. Issue: https://github.com/llvm/torch-mlir/issues/1792
|
||||||
|
"StdCorrectionKeepDimModule_basic",
|
||||||
}
|
}
|
||||||
|
|
||||||
MHLO_PASS_SET = {
|
MHLO_PASS_SET = {
|
||||||
|
|
|
@ -10881,6 +10881,55 @@ def Torch_PrimsConvertElementTypeOp : Torch_Op<"prims.convert_element_type", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_PrimsVarOp : Torch_Op<"prims.var", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `prims::var : (Tensor, int[]?, int, int?) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$inp,
|
||||||
|
AnyTorchOptionalListOfTorchIntType:$dims,
|
||||||
|
Torch_IntType:$correction,
|
||||||
|
AnyTorchOptionalIntType:$output_dtype
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult PrimsVarOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||||
|
}
|
||||||
|
void PrimsVarOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 4, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_PrimsSqrtOp : Torch_Op<"prims.sqrt", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `prims::sqrt : (Tensor) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult PrimsSqrtOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||||
|
}
|
||||||
|
void PrimsSqrtOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 1, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
|
def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
|
|
|
@ -5768,6 +5768,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.prims.sqrt\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||||
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
|
" return %0 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.neg\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.neg\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
|
@ -6032,6 +6036,13 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.prims.var\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.int, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
|
||||||
|
" %none = torch.constant.none\n"
|
||||||
|
" %false = torch.constant.bool false\n"
|
||||||
|
" %0 = torch.derefine %none : !torch.none to !torch.any\n"
|
||||||
|
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %false, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
|
||||||
|
" return %1 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.var.dim\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.var.dim\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list<int> {\n"
|
||||||
" %none = torch.constant.none\n"
|
" %none = torch.constant.none\n"
|
||||||
" %0 = torch.derefine %none : !torch.none to !torch.any\n"
|
" %0 = torch.derefine %none : !torch.none to !torch.any\n"
|
||||||
|
|
|
@ -3425,6 +3425,38 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Decompose `prims.var` op into `aten.var.correction` op.
|
||||||
|
class DecomposePrimsVarOp : public OpRewritePattern<PrimsVarOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(PrimsVarOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
if (!op.getOutputDtype().getType().isa<Torch::NoneType>())
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Unimplemented non-None dtype for prims::var op");
|
||||||
|
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
|
||||||
|
rewriter.replaceOpWithNewOp<AtenVarCorrectionOp>(
|
||||||
|
op, op.getType(), op.getInp(), op.getDims(), op.getCorrection(),
|
||||||
|
/*keepdim=*/cstFalse);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Decompose `prims.sqrt` op into `aten.sqrt` op.
|
||||||
|
class DecomposePrimsSqrtOp : public OpRewritePattern<PrimsSqrtOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(PrimsSqrtOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
rewriter.replaceOpWithNewOp<AtenSqrtOp>(op, op.getType(), op.getSelf());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// The op is decomposed using the Box-Muller transform.
|
// The op is decomposed using the Box-Muller transform.
|
||||||
// Refer: https://en.wikipedia.org/wiki/Box-Muller_transform
|
// Refer: https://en.wikipedia.org/wiki/Box-Muller_transform
|
||||||
|
@ -3659,6 +3691,8 @@ public:
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRandintLowOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRandintLowOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanCorrectionOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanCorrectionOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposePrimsConvertElementTypeOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposePrimsConvertElementTypeOp>(patterns);
|
||||||
|
addPatternIfTargetOpIsIllegal<DecomposePrimsVarOp>(patterns);
|
||||||
|
addPatternIfTargetOpIsIllegal<DecomposePrimsSqrtOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnGeneratorOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnGeneratorOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanOp>(patterns);
|
||||||
|
|
|
@ -438,6 +438,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
||||||
target.addIllegalOp<AtenRandintLowOp>();
|
target.addIllegalOp<AtenRandintLowOp>();
|
||||||
target.addIllegalOp<AtenVarMeanCorrectionOp>();
|
target.addIllegalOp<AtenVarMeanCorrectionOp>();
|
||||||
target.addIllegalOp<PrimsConvertElementTypeOp>();
|
target.addIllegalOp<PrimsConvertElementTypeOp>();
|
||||||
|
target.addIllegalOp<PrimsVarOp>();
|
||||||
|
target.addIllegalOp<PrimsSqrtOp>();
|
||||||
target.addIllegalOp<AtenRandnOp>();
|
target.addIllegalOp<AtenRandnOp>();
|
||||||
target.addIllegalOp<AtenRandnGeneratorOp>();
|
target.addIllegalOp<AtenRandnGeneratorOp>();
|
||||||
target.addIllegalOp<AtenVarMeanOp>();
|
target.addIllegalOp<AtenVarMeanOp>();
|
||||||
|
|
|
@ -675,7 +675,8 @@ void TypeAnalysis::visitOperation(Operation *op,
|
||||||
// Dtype is always float32, except for bfloat16, float16, float64 and nullptr.
|
// Dtype is always float32, except for bfloat16, float16, float64 and nullptr.
|
||||||
if (isa<AtenTanhOp, AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp,
|
if (isa<AtenTanhOp, AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp,
|
||||||
AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op, AtenLog1pOp,
|
AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op, AtenLog1pOp,
|
||||||
AtenRsqrtOp, AtenErfOp, AtenSoftplusOp, AtenFrobeniusNormDimOp>(op)) {
|
AtenRsqrtOp, AtenErfOp, AtenSoftplusOp, AtenFrobeniusNormDimOp,
|
||||||
|
PrimsSqrtOp>(op)) {
|
||||||
ValueKnowledge knowledge =
|
ValueKnowledge knowledge =
|
||||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||||
Type dtype = operands[0]->getValue().dtype;
|
Type dtype = operands[0]->getValue().dtype;
|
||||||
|
@ -978,7 +979,7 @@ void TypeAnalysis::visitOperation(Operation *op,
|
||||||
visitReductionAlongAllDimsOp(op, dtype, operands);
|
visitReductionAlongAllDimsOp(op, dtype, operands);
|
||||||
return;
|
return;
|
||||||
} else if (isa<AtenStdOp, AtenStdDimOp, AtenStdCorrectionOp, AtenVarOp,
|
} else if (isa<AtenStdOp, AtenStdDimOp, AtenStdCorrectionOp, AtenVarOp,
|
||||||
AtenVarDimOp, AtenVarCorrectionOp>(op)) {
|
AtenVarDimOp, AtenVarCorrectionOp, PrimsVarOp>(op)) {
|
||||||
auto input = operands[0]->getValue();
|
auto input = operands[0]->getValue();
|
||||||
visitReductionAlongAllDimsOp(op, input.dtype, operands);
|
visitReductionAlongAllDimsOp(op, input.dtype, operands);
|
||||||
return;
|
return;
|
||||||
|
|
|
@ -107,6 +107,9 @@ def aten〇hardtanh〡shape(self: List[int], min_val: float = -1, max_val: float
|
||||||
def aten〇sqrt〡shape(self: List[int]) -> List[int]:
|
def aten〇sqrt〡shape(self: List[int]) -> List[int]:
|
||||||
return upstream_shape_functions.unary(self)
|
return upstream_shape_functions.unary(self)
|
||||||
|
|
||||||
|
def prims〇sqrt〡shape(self: List[int]) -> List[int]:
|
||||||
|
return upstream_shape_functions.unary(self)
|
||||||
|
|
||||||
def aten〇neg〡shape(self: List[int]) -> List[int]:
|
def aten〇neg〡shape(self: List[int]) -> List[int]:
|
||||||
return upstream_shape_functions.unary(self)
|
return upstream_shape_functions.unary(self)
|
||||||
|
|
||||||
|
@ -307,6 +310,9 @@ def aten〇mean〡shape(self: List[int], dtype: Optional[int] = None) -> List[in
|
||||||
def aten〇var〡shape(self: List[int], unbiased: bool = True) -> List[int]:
|
def aten〇var〡shape(self: List[int], unbiased: bool = True) -> List[int]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
def prims〇var〡shape(inp: List[int], dims: Optional[List[int]], correction: int, output_dtype: Optional[int] = None) -> List[int]:
|
||||||
|
return upstream_shape_functions.sum_mean_dim(inp, dims, False, None)
|
||||||
|
|
||||||
def aten〇var〇dim〡shape(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> List[int]:
|
def aten〇var〇dim〡shape(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> List[int]:
|
||||||
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
|
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None)
|
||||||
|
|
||||||
|
|
|
@ -674,6 +674,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
# ==========================================================================
|
# ==========================================================================
|
||||||
|
|
||||||
emit("prims::convert_element_type : (Tensor, int) -> (Tensor)")
|
emit("prims::convert_element_type : (Tensor, int) -> (Tensor)")
|
||||||
|
emit("prims::var : (Tensor, int[]?, int, int?) -> (Tensor)")
|
||||||
|
emit("prims::sqrt : (Tensor) -> (Tensor)")
|
||||||
|
|
||||||
# ==========================================================================
|
# ==========================================================================
|
||||||
# `quantized::` namespace.
|
# `quantized::` namespace.
|
||||||
|
|
Loading…
Reference in New Issue