mirror of https://github.com/llvm/torch-mlir
[TORCH] The torch definition of aten::gelu has changed: a new str argument,
approximate, has been added.

branch: pull/602/head
parent: ed9bd556b3
commit: abbde7d439
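Background on the new argument (well-known PyTorch semantics, not stated in the diff itself): approximate selects between the exact, erf-based GELU and its tanh approximation; the only accepted values are "none" and "tanh". The standard formulas, in LaTeX:

    \mathrm{GELU}(x) = x\,\Phi(x) = \frac{x}{2}\left(1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right)

    \mathrm{GELU}_{\text{tanh}}(x) = \frac{x}{2}\left(1 + \tanh\!\left(\sqrt{\tfrac{2}{\pi}}\,\bigl(x + 0.044715\,x^{3}\bigr)\right)\right)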
@@ -1342,14 +1342,15 @@ def Torch_AtenGeluOp : Torch_Op<"aten.gelu", [
     AllowsTypeRefinement,
     HasValueSemantics
   ]> {
-  let summary = "Generated op for `aten::gelu : (Tensor) -> (Tensor)`";
+  let summary = "Generated op for `aten::gelu : (Tensor, str) -> (Tensor)`";
   let arguments = (ins
-    AnyTorchTensorType:$self
+    AnyTorchTensorType:$self,
+    Torch_StringType:$approximate
   );
   let results = (outs
     AnyTorchTensorType:$result
   );
-  let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
+  let assemblyFormat = "$self `,` $approximate attr-dict `:` qualified(type($self)) `,` qualified(type($approximate)) `->` qualified(type($result))";
 }
 
 def Torch_AtenPowTensorScalarOp : Torch_Op<"aten.pow.Tensor_Scalar", [
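With the extra operand and the updated assemblyFormat, the op's printed form gains the string operand and its type, roughly (the tensor types here are illustrative placeholders):

    torch.aten.gelu %self, %approximate : !torch.vtensor<[2,3],f32>, !torch.str -> !torch.vtensor<[2,3],f32>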
@@ -1383,6 +1384,20 @@ def Torch_AtenThresholdBackwardOp : Torch_Op<"aten.threshold_backward", [
   let assemblyFormat = "$grad_output `,` $self `,` $threshold attr-dict `:` qualified(type($grad_output)) `,` qualified(type($self)) `,` qualified(type($threshold)) `->` qualified(type($result))";
 }
 
+def Torch_AtenFill_ScalarOp : Torch_Op<"aten.fill_.Scalar", [
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::fill_.Scalar : (Tensor, Scalar) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchScalarType:$value
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self `,` $value attr-dict `:` qualified(type($self)) `,` qualified(type($value)) `->` qualified(type($result))";
+}
+
 def Torch_AtenUniform_Op : Torch_Op<"aten.uniform_", [
     AllowsTypeRefinement
   ]> {
@@ -2103,20 +2118,6 @@ def Torch_AtenSizeOp : Torch_Op<"aten.size", [
   let hasCanonicalizer = 1;
 }
 
-def Torch_AtenFill_ScalarOp : Torch_Op<"aten.fill_.Scalar", [
-    AllowsTypeRefinement
-  ]> {
-  let summary = "Generated op for `aten::fill_.Scalar : (Tensor, Scalar) -> (Tensor)`";
-  let arguments = (ins
-    AnyTorchTensorType:$self,
-    AnyTorchScalarType:$value
-  );
-  let results = (outs
-    AnyTorchTensorType:$result
-  );
-  let assemblyFormat = "$self `,` $value attr-dict `:` qualified(type($self)) `,` qualified(type($value)) `->` qualified(type($result))";
-}
-
 def Torch_AtenBoolTensorOp : Torch_Op<"aten.Bool.Tensor", [
     AllowsTypeRefinement,
     HasValueSemantics
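This hunk and the previous one cancel out: the aten.fill_.Scalar definition is byte-for-byte identical and is only relocated, presumably because the regenerated file orders the ops differently; there is no behavioral change to aten.fill_.Scalar.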
@@ -3830,15 +3831,16 @@ def Torch_AtenGeluBackwardOp : Torch_Op<"aten.gelu_backward", [
     AllowsTypeRefinement,
     HasValueSemantics
   ]> {
-  let summary = "Generated op for `aten::gelu_backward : (Tensor, Tensor) -> (Tensor)`";
+  let summary = "Generated op for `aten::gelu_backward : (Tensor, Tensor, str) -> (Tensor)`";
   let arguments = (ins
-    AnyTorchTensorType:$grad,
-    AnyTorchTensorType:$self
+    AnyTorchTensorType:$grad_output,
+    AnyTorchTensorType:$self,
+    Torch_StringType:$approximate
   );
   let results = (outs
     AnyTorchTensorType:$result
   );
-  let assemblyFormat = "$grad `,` $self attr-dict `:` qualified(type($grad)) `,` qualified(type($self)) `->` qualified(type($result))";
+  let assemblyFormat = "$grad_output `,` $self `,` $approximate attr-dict `:` qualified(type($grad_output)) `,` qualified(type($self)) `,` qualified(type($approximate)) `->` qualified(type($result))";
 }
 
 def Torch_Aten_LogSoftmaxBackwardDataOp : Torch_Op<"aten._log_softmax_backward_data", [
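Beyond adding $approximate, the backward op renames its first operand from $grad to $grad_output, which matches the argument name PyTorch itself uses in the aten::gelu_backward schema.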
@@ -60,6 +60,21 @@ struct torch_constant_float_op_binder {
     return false;
   }
 };
+
+struct torch_constant_str_op_binder {
+  std::string &bind_value;
+
+  /// Creates a matcher instance that binds the value to bv if match succeeds.
+  torch_constant_str_op_binder(std::string &bv) : bind_value(bv) {}
+
+  bool match(Operation *op) {
+    if (auto constantString = dyn_cast<Torch::ConstantStrOp>(op)) {
+      bind_value = constantString.value().str();
+      return true;
+    }
+    return false;
+  }
+};
 } // namespace detail
 
 /// Matches the integer stored in a `torch.constant.bool`.
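The new binder follows the torch_constant_float_op_binder pattern directly above it; ConstantStrOp's value() accessor returns a StringRef, hence the .str() copy into the bound std::string.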
@@ -74,6 +89,12 @@ m_TorchConstantFloat(double *bind_value) {
   return detail::torch_constant_float_op_binder(bind_value);
 }
 
+/// Matches the string value stored in a `torch.constant.str`.
+inline detail::torch_constant_str_op_binder
+m_TorchConstantStr(std::string &bind_value) {
+  return detail::torch_constant_str_op_binder(bind_value);
+}
+
 namespace detail {
 /// Matches the bool stored in a `torch.constant.bool`.
 struct torch_constant_bool_op_binder {
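A minimal usage sketch for the new matcher (it mirrors the lowering hunks below; the AtenGeluOp handle named gelu is assumed):

    // Succeeds only when the operand is produced by a torch.constant.str;
    // fails for block arguments or non-constant producers.
    std::string approximate;
    if (!matchPattern(gelu.approximate(), m_TorchConstantStr(approximate)))
      return nullptr; // not a compile-time-constant string
    // `approximate` now holds the constant value, e.g. "none" or "tanh".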
@@ -1630,6 +1630,11 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
       gelu.emitError("unimplemented: non-floating point dtype");
       return nullptr;
     }
+    // TODO: Take approximation into account.
+    std::string approximate;
+    if (!matchPattern(gelu.approximate(), m_TorchConstantStr(approximate)) ||
+        approximate != "none")
+      return nullptr;
     Value cdf = buildUnitNormalCdf(b, loc, payloadArgs[0]);
     return b.create<arith::MulFOp>(loc, payloadArgs[0], cdf);
   }
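The guard added here bails out of the linalg lowering (returns nullptr) for anything but approximate == "none", leaving the TODO open for the "tanh" case. A possible sketch for discharging that TODO, not part of this commit, assuming the payload builder b, loc, and the math dialect's TanhOp are available at this point:

    // gelu_tanh(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    Value x = payloadArgs[0];
    Type elemTy = x.getType();
    auto cst = [&](double v) -> Value {
      return b.create<arith::ConstantOp>(loc, FloatAttr::get(elemTy, v));
    };
    Value x3 = b.create<arith::MulFOp>(loc, x, b.create<arith::MulFOp>(loc, x, x));
    Value inner = b.create<arith::AddFOp>(
        loc, x, b.create<arith::MulFOp>(loc, cst(0.044715), x3));
    // 0.7978845608028654 ~= sqrt(2/pi)
    Value scaled = b.create<arith::MulFOp>(loc, cst(0.7978845608028654), inner);
    Value onePlusTanh = b.create<arith::AddFOp>(
        loc, cst(1.0), b.create<math::TanhOp>(loc, scaled));
    return b.create<arith::MulFOp>(
        loc, b.create<arith::MulFOp>(loc, cst(0.5), x), onePlusTanh);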
@@ -1641,6 +1646,12 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
      geluBackward.emitError("unimplemented: non-floating point dtype");
       return nullptr;
     }
+    // TODO: Take approximation into account.
+    std::string approximate;
+    if (!matchPattern(geluBackward.approximate(),
+                      m_TorchConstantStr(approximate)) ||
+        approximate != "none")
+      return nullptr;
     Type elementType = payloadArgs[1].getType();
     Value cstAlpha0 = b.create<arith::ConstantOp>(
         loc, FloatAttr::get(elementType, 1.12837916709551257390));
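For reference, the constant 1.12837916709551257390 in the context lines above is 2/\sqrt{\pi}, the scale factor in \operatorname{erf}'(t) = \frac{2}{\sqrt{\pi}}\,e^{-t^{2}}; it shows up because differentiating the erf-based GELU gives

    \frac{d}{dx}\,\mathrm{GELU}(x) = \Phi(x) + x\,\varphi(x),
    \qquad \varphi(x) = \frac{1}{\sqrt{2\pi}}\,e^{-x^{2}/2}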
@@ -497,7 +497,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
     emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
     emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
     emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)")
-    emit("aten::gelu : (Tensor) -> (Tensor)")
+    emit("aten::gelu : (Tensor, str) -> (Tensor)")
     emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)")
     emit("aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)")
@@ -693,7 +693,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
     # backprop ops
     emit("aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")
     emit("aten::tanh_backward : (Tensor, Tensor) -> (Tensor)")
-    emit("aten::gelu_backward : (Tensor, Tensor) -> (Tensor)")
+    emit("aten::gelu_backward : (Tensor, Tensor, str) -> (Tensor)")
     emit("aten::_log_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")
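These two emit() signature strings in the ODS generator script are the source of truth for the TableGen definitions in the first hunks: regenerating from the updated schemas produces the new $approximate operands (and the relocated aten.fill_.Scalar block), so each string must match the registered JIT operator schema exactly.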