forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathivalue.cpp
81 lines (72 loc) · 2.28 KB
/
ivalue.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#include <ATen/core/ivalue.h>
#include <ATen/core/Formatting.h>

#include <cmath>
#include <cstdint>
#include <iomanip>
#include <limits>
namespace c10 {
namespace ivalue {

/// Allocates a new ConstantString that takes ownership of `value` and returns
/// it wrapped in a c10::intrusive_ptr. The argument is accepted by value and
/// then moved, so callers passing a temporary avoid an extra copy.
CAFFE2_API c10::intrusive_ptr<ConstantString> ConstantString::create(
    std::string value) {
  auto created = c10::make_intrusive<ConstantString>(std::move(value));
  return created;
}

} // namespace ivalue
namespace {

/// Writes the elements of a list-like IValue payload to `out`, separated by
/// ", " and surrounded by `start` and `finish` (e.g. "[1, 2, 3]").
///
/// @tparam List  pointer-like handle (such as the intrusive_ptr returned by
///               IValue::toIntList() / toTuple()) whose pointee exposes an
///               `elements()` container. NOTE(review): range-for requires that
///               container to be iterable; the list payloads are expected to
///               be std::vector-backed — confirm if new list kinds are added.
/// @param out     stream to write to.
/// @param v       the list handle; it is dereferenced, so it must be non-null.
/// @param start   text written before the first element.
/// @param finish  text written after the last element.
/// @return `out`, to allow chaining.
template<typename List>
std::ostream& printList(std::ostream& out, const List& v,
                        const std::string& start, const std::string& finish) {
  out << start;
  // Empty before the first element, ", " afterwards.
  const char* separator = "";
  for (const auto& element : v->elements()) {
    // Wrap each element in an IValue so nested values use IValue printing
    // rather than the element type's default operator<<.
    out << separator << IValue(element);
    separator = ", ";
  }
  out << finish;
  return out;
}

} // anonymous namespace
namespace {

/// Prints a double for IValue display.
///
/// Integral values that fit in int64_t are printed as the integer followed by
/// a trailing "." (e.g. 3.0 -> "3."). All other values (fractional, very large,
/// subnormal, infinite, NaN) are printed with
/// std::numeric_limits<double>::max_digits10 precision. The stream's previous
/// precision is restored afterwards.
///
/// @param out  stream to write to.
/// @param d    value to print.
/// @return `out`, to allow chaining.
std::ostream& printDouble(std::ostream& out, double d) {
  // Converting a double whose truncated value lies outside the int64_t range
  // to int64_t is undefined behavior, so the integral fast path is only taken
  // for d in [-2^63, 2^63). 2^63 is exactly representable as a double.
  constexpr double kInt64Limit = 9223372036854775808.0;
  const int category = std::fpclassify(d);
  if ((category == FP_NORMAL || category == FP_ZERO) &&
      d >= -kInt64Limit && d < kInt64Limit) {
    const auto as_int = static_cast<int64_t>(d);
    if (static_cast<double>(as_int) == d) {
      return out << as_int << ".";
    }
  }
  const auto orig_prec = out.precision();
  return out
      << std::setprecision(std::numeric_limits<double>::max_digits10)
      << d
      << std::setprecision(orig_prec);
}

} // anonymous namespace

/// Writes a human-readable representation of `v` to `out`.
///
/// Scalars print directly. Bools print as "True"/"False". Tuples print as
/// "(a, b)" and every list kind prints as "[a, b]". Strings print their raw
/// contents without quotes, and futures print the placeholder "Future".
/// Raises (via AT_ERROR) if the tag is not handled by the switch.
std::ostream& operator<<(std::ostream & out, const IValue & v) {
  switch(v.tag) {
    case IValue::Tag::None:
      return out << v.toNone();
    case IValue::Tag::Tensor:
      return out << v.toTensor();
    case IValue::Tag::Double:
      return printDouble(out, v.toDouble());
    case IValue::Tag::Int:
      return out << v.toInt();
    case IValue::Tag::Bool:
      return out << (v.toBool() ? "True" : "False");
    case IValue::Tag::Tuple:
      return printList(out, v.toTuple(), "(", ")");
    case IValue::Tag::IntList:
      return printList(out, v.toIntList(), "[", "]");
    case IValue::Tag::DoubleList:
      return printList(out, v.toDoubleList(), "[", "]");
    case IValue::Tag::BoolList:
      return printList(out, v.toBoolList(), "[", "]");
    case IValue::Tag::String:
      return out << v.toStringRef();
    case IValue::Tag::TensorList:
      return printList(out, v.toTensorList(), "[", "]");
    case IValue::Tag::Blob:
      return out << v.toBlob();
    case IValue::Tag::GenericList:
      return printList(out, v.toGenericList(), "[", "]");
    case IValue::Tag::Future:
      return out << "Future";
  }
  // Reached only if v.tag holds a value not covered above.
  AT_ERROR("Tag not found\n");
}
#undef TORCH_FORALL_TAGS
} // namespace c10